feat(semantics): add quantitative semantics

This commit is contained in:
Anand Balakrishnan 2023-04-26 12:59:54 -07:00
parent d3b30deaa3
commit bfd5178982
No known key found for this signature in database
2 changed files with 65 additions and 146 deletions

View file

@ -83,7 +83,7 @@ impl Semantics for BooleanSemantics {
} }
Ok(arg) Ok(arg)
} }
BoolExpr::Until { lhs, rhs } => todo!(), BoolExpr::Until { lhs: _, rhs: _ } => todo!(),
} }
} }
} }

View file

@ -1,132 +1,20 @@
use std::iter::zip;
use argus_core::expr::BoolExpr; use argus_core::expr::BoolExpr;
use argus_core::prelude::*; use argus_core::prelude::*;
use argus_core::signals::traits::{BaseSignal, SignalAbs, SignalMinMax, SignalNumCast}; use argus_core::signals::traits::{SignalAbs, SignalMinMax};
use argus_core::signals::SignalNumCast;
use crate::eval::NumExprEval; use crate::eval::eval_num_expr;
use crate::{Semantics, Trace}; use crate::{Semantics, Trace};
macro_rules! num_signal_binop_impl { fn top_or_bot(sig: &Signal<bool>) -> Signal<f64> {
($lhs:ident, $rhs:ident, $op:ident, [$( $type:ident ), *]) => { match sig {
paste::paste!{ Signal::Empty => Signal::Empty,
{ Signal::Constant { value } => Signal::constant(*value).to_f64().unwrap(),
use argus_core::prelude::*; Signal::Sampled { values, time_points } => zip(time_points, values)
use argus_core::ArgusError;
use AnySignal::*;
match ($lhs, $rhs) {
$(
([<$type >](lhs), [< $type >](rhs)) => Ok(AnySignal::from($op(&lhs, &rhs))),
([<$type >](lhs), [< Const $type >](rhs)) => Ok(AnySignal::from($op(&lhs, &rhs))),
([<Const $type >](lhs), [< $type >](rhs)) => Ok(AnySignal::from($op(&lhs, &rhs))),
([<Const $type >](lhs), [< Const $type >](rhs)) => Ok(AnySignal::from($op(&lhs, &rhs))),
)*
_ => Err(ArgusError::InvalidOperation),
}
}
}
};
}
fn less_than<T, Sig1, Sig2, Ret>(lhs: &Sig1, rhs: &Sig2) -> Ret
where
Sig1: SignalNumCast + BaseSignal<Value = T>,
Sig2: SignalNumCast + BaseSignal<Value = T>,
Ret: BaseSignal<Value = f64>,
for<'a> &'a <Sig2 as SignalNumCast>::Output<f64>:
std::ops::Sub<&'a <Sig1 as SignalNumCast>::Output<f64>, Output = Ret>,
{
let lhs = lhs.to_f64().unwrap();
let rhs = rhs.to_f64().unwrap();
&rhs - &lhs
}
fn greater_than<T, Sig1, Sig2, Ret>(lhs: &Sig1, rhs: &Sig2) -> Ret
where
Sig1: SignalNumCast + BaseSignal<Value = T>,
Sig2: SignalNumCast + BaseSignal<Value = T>,
Ret: BaseSignal<Value = f64>,
for<'a> &'a <Sig1 as SignalNumCast>::Output<f64>:
std::ops::Sub<&'a <Sig2 as SignalNumCast>::Output<f64>, Output = Ret>,
{
let lhs = lhs.to_f64().unwrap();
let rhs = rhs.to_f64().unwrap();
&lhs - &rhs
}
fn equal_to<T, Sig1, Sig2, Ret>(lhs: &Sig1, rhs: &Sig2) -> Ret
where
Sig1: SignalNumCast + BaseSignal<Value = T>,
Sig2: SignalNumCast + BaseSignal<Value = T>,
Ret: BaseSignal<Value = f64> + SignalAbs,
for<'a> &'a <Sig1 as SignalNumCast>::Output<f64>:
std::ops::Sub<&'a <Sig2 as SignalNumCast>::Output<f64>, Output = Ret>,
{
let lhs = lhs.to_f64().unwrap();
let rhs = rhs.to_f64().unwrap();
(&lhs - &rhs).abs()
}
fn not_equal_to<T, Sig1, Sig2, Ret>(lhs: &Sig1, rhs: &Sig2) -> Ret
where
Sig1: SignalNumCast + BaseSignal<Value = T>,
Sig2: SignalNumCast + BaseSignal<Value = T>,
Ret: BaseSignal<Value = f64> + SignalAbs,
for<'a> &'a Ret: core::ops::Neg<Output = Ret>,
for<'a> &'a <Sig1 as SignalNumCast>::Output<f64>:
std::ops::Sub<&'a <Sig2 as SignalNumCast>::Output<f64>, Output = Ret>,
{
let lhs = lhs.to_f64().unwrap();
let rhs = rhs.to_f64().unwrap();
-&((&lhs - &rhs).abs())
}
macro_rules! signal_bool_op_impl {
// Unary bool opeartions
(! $signal:ident) => {{
use argus_core::prelude::*;
use AnySignal::*;
match $signal {
Float(sig) => Ok(AnySignal::from(-(&sig))),
ConstFloat(sig) => Ok(AnySignal::from(-(&sig))),
_ => unreachable!("no other signal is expected in quantitative semantics"),
}
}};
($op:ident, $lhs:ident, $rhs:ident) => {
paste::paste! {
{
use argus_core::prelude::*;
use AnySignal::*;
match ($lhs, $rhs) {
(Float(lhs), Float(rhs)) => AnySignal::from(lhs.$op(&rhs)),
(Float(lhs), ConstFloat(rhs)) => AnySignal::from(lhs.$op(&rhs)),
(ConstFloat(lhs), Float(rhs)) => AnySignal::from(lhs.$op(&rhs)),
(ConstFloat(lhs), ConstFloat(rhs)) => AnySignal::from(lhs.$op(&rhs)),
_ => panic!("mismatched argument types for {} operation", stringify!($op)),
}
}
}
};
}
fn bool_to_f64_sig(sig: &Signal<bool>) -> Signal<f64> {
sig.iter()
.map(|(&t, &v)| if v { (t, f64::INFINITY) } else { (t, f64::NEG_INFINITY) }) .map(|(&t, &v)| if v { (t, f64::INFINITY) } else { (t, f64::NEG_INFINITY) })
.collect() .collect(),
}
fn bool_to_f64_const_sig(sig: &ConstantSignal<bool>) -> ConstantSignal<f64> {
if sig.value {
ConstantSignal::new(f64::INFINITY)
} else {
ConstantSignal::new(f64::NEG_INFINITY)
}
}
fn top_or_bot_sig(val: bool) -> ConstantSignal<f64> {
if val {
ConstantSignal::new(f64::INFINITY)
} else {
ConstantSignal::new(f64::NEG_INFINITY)
} }
} }
@ -134,35 +22,31 @@ fn top_or_bot_sig(val: bool) -> ConstantSignal<f64> {
pub struct QuantitativeSemantics; pub struct QuantitativeSemantics;
impl Semantics for QuantitativeSemantics { impl Semantics for QuantitativeSemantics {
type Output = AnySignal; type Output = Signal<f64>;
type Context = (); type Context = ();
fn eval(expr: &BoolExpr, trace: &impl Trace, ctx: Self::Context) -> ArgusResult<Self::Output> { fn eval(expr: &BoolExpr, trace: &impl Trace, ctx: Self::Context) -> ArgusResult<Self::Output> {
match expr { let ret: Self::Output = match expr {
BoolExpr::BoolLit(val) => Ok(top_or_bot_sig(*val).into()), BoolExpr::BoolLit(val) => top_or_bot(&Signal::constant(*val)),
BoolExpr::BoolVar { name } => { BoolExpr::BoolVar { name } => {
let sig = trace.get(name.as_str()).ok_or(ArgusError::SignalNotPresent)?; let sig = trace.get::<bool>(name.as_str()).ok_or(ArgusError::SignalNotPresent)?;
match sig { top_or_bot(sig)
AnySignal::ConstBool(bool_sig) => Ok(bool_to_f64_const_sig(bool_sig).into()),
AnySignal::Bool(sig) => Ok(bool_to_f64_sig(sig).into()),
_ => Err(ArgusError::InvalidSignalType),
}
} }
BoolExpr::Cmp { op, lhs, rhs } => { BoolExpr::Cmp { op, lhs, rhs } => {
use argus_core::expr::Ordering::*; use argus_core::expr::Ordering::*;
let lhs = NumExprEval::eval(lhs, trace); let lhs = eval_num_expr::<f64>(lhs, trace)?;
let rhs = NumExprEval::eval(rhs, trace); let rhs = eval_num_expr::<f64>(rhs, trace)?;
match op { match op {
Eq => num_signal_binop_impl!(lhs, rhs, equal_to, [Int, UInt, Float]), Eq => -&((&lhs - &rhs).abs()),
NotEq => num_signal_binop_impl!(lhs, rhs, not_equal_to, [Int, UInt, Float]), NotEq => (&lhs - &rhs).abs(),
Less { strict: _ } => num_signal_binop_impl!(lhs, rhs, less_than, [Int, UInt, Float]), Less { strict: _ } => &rhs - &lhs,
Greater { strict: _ } => num_signal_binop_impl!(lhs, rhs, greater_than, [Int, UInt, Float]), Greater { strict: _ } => &lhs - &rhs,
} }
} }
BoolExpr::Not { arg } => { BoolExpr::Not { arg } => {
let arg = Self::eval(arg, trace, ctx)?; let arg = Self::eval(arg, trace, ctx)?;
signal_bool_op_impl!(!arg) -&arg
} }
BoolExpr::And { args } => { BoolExpr::And { args } => {
assert!(args.len() >= 2); assert!(args.len() >= 2);
@ -171,8 +55,8 @@ impl Semantics for QuantitativeSemantics {
.map(|arg| Self::eval(arg, trace, ctx)) .map(|arg| Self::eval(arg, trace, ctx))
.collect::<ArgusResult<Vec<_>>>()?; .collect::<ArgusResult<Vec<_>>>()?;
args.into_iter() args.into_iter()
.reduce(|lhs, rhs| signal_bool_op_impl!(min, lhs, rhs)) .reduce(|lhs, rhs| lhs.min(&rhs))
.ok_or(ArgusError::InvalidOperation) .ok_or(ArgusError::InvalidOperation)?
} }
BoolExpr::Or { args } => { BoolExpr::Or { args } => {
assert!(args.len() >= 2); assert!(args.len() >= 2);
@ -181,9 +65,44 @@ impl Semantics for QuantitativeSemantics {
.map(|arg| Self::eval(arg, trace, ctx)) .map(|arg| Self::eval(arg, trace, ctx))
.collect::<ArgusResult<Vec<_>>>()?; .collect::<ArgusResult<Vec<_>>>()?;
args.into_iter() args.into_iter()
.reduce(|lhs, rhs| signal_bool_op_impl!(max, lhs, rhs)) .reduce(|lhs, rhs| lhs.max(&rhs))
.ok_or(ArgusError::InvalidOperation) .ok_or(ArgusError::InvalidOperation)?
}
BoolExpr::Next { arg: _ } => todo!(),
BoolExpr::Always { arg } => {
let mut arg = Self::eval(arg, trace, ctx)?;
match &mut arg {
// if signal is empty or constant, return the signal itself.
// This works because if a signal is positive everywhere, then it must
// "always be positive" (and vice versa).
Signal::Empty | Signal::Constant { value: _ } => (),
Signal::Sampled { values, time_points } => {
// Compute the min in a expanding window fashion from the back
for i in (0..(time_points.len() - 1)).rev() {
values[i] = values[i].min(values[i + 1]);
} }
} }
} }
arg
}
BoolExpr::Eventually { arg } => {
let mut arg = Self::eval(arg, trace, ctx)?;
match &mut arg {
// if signal is empty or constant, return the signal itself.
// This works because if a signal is positive somewhere, then it must
// "eventually be positive" (and vice versa).
Signal::Empty | Signal::Constant { value: _ } => (),
Signal::Sampled { values, time_points } => {
// Compute the max in a expanding window fashion from the back
for i in (0..(time_points.len() - 1)).rev() {
values[i] = values[i].max(values[i + 1]);
}
}
}
arg
}
BoolExpr::Until { lhs: _, rhs: _ } => todo!(),
};
Ok(ret)
}
} }