feat!: make interpolation method explicit

All methods that need to perform interpolation of some sort need an
explicit interpolation method. In Rust, this manifests as a generic
parameter, while in Python, this is a string parameter.
This commit is contained in:
Anand Balakrishnan 2023-10-04 14:42:51 -07:00
parent e2cff9449e
commit 50d5a0a78a
8 changed files with 221 additions and 296 deletions

View file

@ -4,7 +4,7 @@ use std::time::Duration;
use argus_core::expr::*;
use argus_core::prelude::*;
use argus_core::signals::interpolation::Linear;
use argus_core::signals::SignalPartialOrd;
use argus_core::signals::{InterpolationMethod, SignalPartialOrd};
use crate::semantics::QuantitativeSemantics;
use crate::traits::Trace;
@ -13,7 +13,11 @@ use crate::utils::lemire_minmax::MonoWedge;
pub struct BooleanSemantics;
impl BooleanSemantics {
pub fn eval(expr: &BoolExpr, trace: &impl Trace) -> ArgusResult<Signal<bool>> {
pub fn eval<BoolI, NumI>(expr: &BoolExpr, trace: &impl Trace) -> ArgusResult<Signal<bool>>
where
BoolI: InterpolationMethod<bool>,
NumI: InterpolationMethod<f64>,
{
let ret = match expr {
BoolExpr::BoolLit(val) => Signal::constant(val.0),
BoolExpr::BoolVar(BoolVar { name }) => trace
@ -22,8 +26,8 @@ impl BooleanSemantics {
.clone(),
BoolExpr::Cmp(Cmp { op, lhs, rhs }) => {
use argus_core::expr::Ordering::*;
let lhs = QuantitativeSemantics::eval_num_expr::<f64>(lhs, trace)?;
let rhs = QuantitativeSemantics::eval_num_expr::<f64>(rhs, trace)?;
let lhs = QuantitativeSemantics::eval_num_expr::<f64, NumI>(lhs, trace)?;
let rhs = QuantitativeSemantics::eval_num_expr::<f64, NumI>(rhs, trace)?;
match op {
Eq => lhs.signal_eq(&rhs).unwrap(),
@ -35,46 +39,48 @@ impl BooleanSemantics {
}
}
BoolExpr::Not(Not { arg }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<BoolI, NumI>(arg, trace)?;
!&arg
}
BoolExpr::And(And { args }) => {
assert!(args.len() >= 2);
args.iter()
.map(|arg| Self::eval(arg, trace))
.try_fold(Signal::const_true(), |acc, item| {
args.iter().map(|arg| Self::eval::<BoolI, NumI>(arg, trace)).try_fold(
Signal::const_true(),
|acc, item| {
let item = item?;
Ok(acc.and(&item))
})?
Ok(acc.and::<BoolI>(&item))
},
)?
}
BoolExpr::Or(Or { args }) => {
assert!(args.len() >= 2);
args.iter()
.map(|arg| Self::eval(arg, trace))
.try_fold(Signal::const_true(), |acc, item| {
args.iter().map(|arg| Self::eval::<BoolI, NumI>(arg, trace)).try_fold(
Signal::const_true(),
|acc, item| {
let item = item?;
Ok(acc.or(&item))
})?
Ok(acc.or::<BoolI>(&item))
},
)?
}
BoolExpr::Next(Next { arg }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<BoolI, NumI>(arg, trace)?;
compute_next(arg)?
}
BoolExpr::Oracle(Oracle { steps, arg }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<BoolI, NumI>(arg, trace)?;
compute_oracle(arg, *steps)?
}
BoolExpr::Always(Always { arg, interval }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<BoolI, NumI>(arg, trace)?;
compute_always(arg, interval)?
}
BoolExpr::Eventually(Eventually { arg, interval }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<BoolI, NumI>(arg, trace)?;
compute_eventually(arg, interval)?
}
BoolExpr::Until(Until { lhs, rhs, interval }) => {
let lhs = Self::eval(lhs, trace)?;
let rhs = Self::eval(rhs, trace)?;
let lhs = Self::eval::<BoolI, NumI>(lhs, trace)?;
let rhs = Self::eval::<BoolI, NumI>(rhs, trace)?;
compute_until(lhs, rhs, interval)?
}
};
@ -389,7 +395,7 @@ mod tests {
let trace = MyTrace { signals };
let rob = BooleanSemantics::eval(&spec, &trace).unwrap();
let rob = BooleanSemantics::eval::<Linear, Linear>(&spec, &trace).unwrap();
let expected = Signal::from_iter(vec![
(Duration::from_secs_f64(0.0), false),
(Duration::from_secs_f64(0.7), false),
@ -421,7 +427,7 @@ mod tests {
)]);
let trace = MyTrace { signals };
let rob = BooleanSemantics::eval(&spec, &trace).unwrap();
let rob = BooleanSemantics::eval::<Linear, Linear>(&spec, &trace).unwrap();
let Signal::Sampled { values, time_points: _ } = rob else {
panic!("boolean semantics should remain sampled");
@ -441,7 +447,7 @@ mod tests {
)]);
let trace = MyTrace { signals };
let rob = BooleanSemantics::eval(&spec, &trace).unwrap();
let rob = BooleanSemantics::eval::<Linear, Linear>(&spec, &trace).unwrap();
println!("{:#?}", rob);
let Signal::Sampled { values, time_points: _ } = rob else {

View file

@ -4,7 +4,7 @@ use std::time::Duration;
use argus_core::expr::*;
use argus_core::prelude::*;
use argus_core::signals::interpolation::Linear;
use argus_core::signals::SignalAbs;
use argus_core::signals::{InterpolationMethod, SignalAbs};
use num_traits::{Num, NumCast};
use crate::traits::Trace;
@ -13,7 +13,7 @@ use crate::utils::lemire_minmax::MonoWedge;
pub struct QuantitativeSemantics;
impl QuantitativeSemantics {
pub fn eval(expr: &BoolExpr, trace: &impl Trace) -> ArgusResult<Signal<f64>> {
pub fn eval<I: InterpolationMethod<f64>>(expr: &BoolExpr, trace: &impl Trace) -> ArgusResult<Signal<f64>> {
let ret = match expr {
BoolExpr::BoolLit(val) => top_or_bot(&Signal::constant(val.0)),
BoolExpr::BoolVar(BoolVar { name }) => trace
@ -22,23 +22,20 @@ impl QuantitativeSemantics {
.map(top_or_bot)?,
BoolExpr::Cmp(Cmp { op, lhs, rhs }) => {
use argus_core::expr::Ordering::*;
let lhs = Self::eval_num_expr::<f64>(lhs, trace)?;
let rhs = Self::eval_num_expr::<f64>(rhs, trace)?;
let lhs = Self::eval_num_expr::<f64, I>(lhs, trace)?;
let rhs = Self::eval_num_expr::<f64, I>(rhs, trace)?;
match op {
Eq => -&((&lhs - &rhs).abs()),
NotEq => (&lhs - &rhs).abs(),
Less { strict: _ } => &rhs - &lhs,
Greater { strict: _ } => &lhs - &rhs,
Eq => lhs.abs_diff::<_, I>(&rhs).negate(),
NotEq => lhs.abs_diff::<_, I>(&rhs).negate(),
Less { strict: _ } => rhs.sub::<_, I>(&lhs),
Greater { strict: _ } => lhs.sub::<_, I>(&rhs),
}
}
BoolExpr::Not(Not { arg }) => {
let arg = Self::eval(arg, trace)?;
-&arg
}
BoolExpr::Not(Not { arg }) => Self::eval::<I>(arg, trace)?.negate(),
BoolExpr::And(And { args }) => {
assert!(args.len() >= 2);
args.iter().map(|arg| Self::eval(arg, trace)).try_fold(
args.iter().map(|arg| Self::eval::<I>(arg, trace)).try_fold(
Signal::constant(f64::INFINITY),
|acc, item| {
let item = item?;
@ -48,7 +45,7 @@ impl QuantitativeSemantics {
}
BoolExpr::Or(Or { args }) => {
assert!(args.len() >= 2);
args.iter().map(|arg| Self::eval(arg, trace)).try_fold(
args.iter().map(|arg| Self::eval::<I>(arg, trace)).try_fold(
Signal::constant(f64::NEG_INFINITY),
|acc, item| {
let item = item?;
@ -57,39 +54,40 @@ impl QuantitativeSemantics {
)?
}
BoolExpr::Next(Next { arg }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<I>(arg, trace)?;
compute_next(arg)?
}
BoolExpr::Oracle(Oracle { steps, arg }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<I>(arg, trace)?;
compute_oracle(arg, *steps)?
}
BoolExpr::Always(Always { arg, interval }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<I>(arg, trace)?;
compute_always(arg, interval)?
}
BoolExpr::Eventually(Eventually { arg, interval }) => {
let arg = Self::eval(arg, trace)?;
let arg = Self::eval::<I>(arg, trace)?;
compute_eventually(arg, interval)?
}
BoolExpr::Until(Until { lhs, rhs, interval }) => {
let lhs = Self::eval(lhs, trace)?;
let rhs = Self::eval(rhs, trace)?;
let lhs = Self::eval::<I>(lhs, trace)?;
let rhs = Self::eval::<I>(rhs, trace)?;
compute_until(lhs, rhs, interval)?
}
};
Ok(ret)
}
pub fn eval_num_expr<T>(root: &NumExpr, trace: &impl Trace) -> ArgusResult<Signal<T>>
pub fn eval_num_expr<T, I>(root: &NumExpr, trace: &impl Trace) -> ArgusResult<Signal<T>>
where
T: Num + NumCast,
for<'a> &'a Signal<T>: std::ops::Neg<Output = Signal<T>>,
for<'a> &'a Signal<T>: std::ops::Add<&'a Signal<T>, Output = Signal<T>>,
for<'a> &'a Signal<T>: std::ops::Sub<&'a Signal<T>, Output = Signal<T>>,
for<'a> &'a Signal<T>: std::ops::Mul<&'a Signal<T>, Output = Signal<T>>,
for<'a> &'a Signal<T>: std::ops::Div<&'a Signal<T>, Output = Signal<T>>,
T: Num + NumCast + Clone,
for<'a> &'a T: core::ops::Neg<Output = T>,
for<'a> &'a T: core::ops::Add<&'a T, Output = T>,
for<'a> &'a T: core::ops::Sub<&'a T, Output = T>,
for<'a> &'a T: core::ops::Mul<&'a T, Output = T>,
for<'a> &'a T: core::ops::Div<&'a T, Output = T>,
Signal<T>: SignalAbs,
I: InterpolationMethod<T>,
{
match root {
NumExpr::IntLit(val) => Signal::constant(val.0).num_cast(),
@ -98,35 +96,35 @@ impl QuantitativeSemantics {
NumExpr::IntVar(IntVar { name }) => trace.get::<i64>(name.as_str()).unwrap().num_cast(),
NumExpr::UIntVar(UIntVar { name }) => trace.get::<u64>(name.as_str()).unwrap().num_cast(),
NumExpr::FloatVar(FloatVar { name }) => trace.get::<f64>(name.as_str()).unwrap().num_cast(),
NumExpr::Neg(Neg { arg }) => Self::eval_num_expr::<T>(arg, trace).map(|sig| -&sig),
NumExpr::Neg(Neg { arg }) => Self::eval_num_expr::<T, I>(arg, trace).map(|sig| sig.negate()),
NumExpr::Add(Add { args }) => {
let mut ret: Signal<T> = Signal::<T>::zero();
for arg in args.iter() {
let arg = Self::eval_num_expr::<T>(arg, trace)?;
ret = &ret + &arg;
let arg = Self::eval_num_expr::<T, I>(arg, trace)?;
ret = ret.add::<_, I>(&arg);
}
Ok(ret)
}
NumExpr::Sub(Sub { lhs, rhs }) => {
let lhs = Self::eval_num_expr::<T>(lhs, trace)?;
let rhs = Self::eval_num_expr::<T>(rhs, trace)?;
Ok(&lhs - &rhs)
let lhs = Self::eval_num_expr::<T, I>(lhs, trace)?;
let rhs = Self::eval_num_expr::<T, I>(rhs, trace)?;
Ok(lhs.sub::<_, I>(&rhs))
}
NumExpr::Mul(Mul { args }) => {
let mut ret: Signal<T> = Signal::<T>::one();
for arg in args.iter() {
let arg = Self::eval_num_expr::<T>(arg, trace)?;
ret = &ret * &arg;
let arg = Self::eval_num_expr::<T, I>(arg, trace)?;
ret = ret.mul::<_, I>(&arg);
}
Ok(ret)
}
NumExpr::Div(Div { dividend, divisor }) => {
let dividend = Self::eval_num_expr::<T>(dividend, trace)?;
let divisor = Self::eval_num_expr::<T>(divisor, trace)?;
Ok(&dividend / &divisor)
let dividend = Self::eval_num_expr::<T, I>(dividend, trace)?;
let divisor = Self::eval_num_expr::<T, I>(divisor, trace)?;
Ok(dividend.div::<_, I>(&divisor))
}
NumExpr::Abs(Abs { arg }) => {
let arg = Self::eval_num_expr::<T>(arg, trace)?;
let arg = Self::eval_num_expr::<T, I>(arg, trace)?;
Ok(arg.abs())
}
}
@ -201,9 +199,9 @@ fn compute_always(signal: Signal<f64>, interval: &Interval) -> ArgusResult<Signa
/// Compute timed always for the interval `[a, b]` (or, if `b` is `None`, `[a, ..]`.
fn compute_timed_always(signal: Signal<f64>, a: Duration, b: Option<Duration>) -> ArgusResult<Signal<f64>> {
let z1 = -signal;
let z1 = signal.negate();
let z2 = compute_timed_eventually(z1, a, b)?;
Ok(-z2)
Ok(z2.negate())
}
/// Compute untimed always
@ -418,6 +416,7 @@ mod tests {
use std::time::Duration;
use argus_core::expr::ExprBuilder;
use argus_core::signals::interpolation::Constant;
use argus_core::signals::AnySignal;
use itertools::assert_equal;
@ -465,7 +464,7 @@ mod tests {
let spec = expr_builder.float_const(5.0);
let trace = MyTrace::default();
let robustness = QuantitativeSemantics::eval_num_expr::<f64>(&spec, &trace).unwrap();
let robustness = QuantitativeSemantics::eval_num_expr::<f64, Linear>(&spec, &trace).unwrap();
assert!(matches!(robustness, Signal::Constant { value } if value == 5.0));
}
@ -502,7 +501,7 @@ mod tests {
let trace = MyTrace { signals };
let rob = QuantitativeSemantics::eval_num_expr::<f64>(&spec, &trace).unwrap();
let rob = QuantitativeSemantics::eval_num_expr::<f64, Linear>(&spec, &trace).unwrap();
let expected = Signal::from_iter(vec![
(Duration::from_secs_f64(0.0), 1.3 + 2.5),
(Duration::from_secs_f64(0.7), 3.0 + 4.0),
@ -536,7 +535,7 @@ mod tests {
let trace = MyTrace { signals };
let rob = QuantitativeSemantics::eval_num_expr::<f64>(&spec, &trace).unwrap();
let rob = QuantitativeSemantics::eval_num_expr::<f64, Linear>(&spec, &trace).unwrap();
let expected = Signal::from_iter(vec![
(Duration::from_secs_f64(0.0), 1.3 + 2.5),
(Duration::from_secs_f64(0.7), 3.0 + 4.0),
@ -567,7 +566,7 @@ mod tests {
let trace = MyTrace { signals };
let rob = QuantitativeSemantics::eval(&spec, &trace).unwrap();
let rob = QuantitativeSemantics::eval::<Linear>(&spec, &trace).unwrap();
let expected = Signal::from_iter(vec![
(Duration::from_secs_f64(0.0), 0.0 - 1.3),
(Duration::from_secs_f64(0.7), 0.0 - 3.0),
@ -599,7 +598,7 @@ mod tests {
)]);
let trace = MyTrace { signals };
let rob = QuantitativeSemantics::eval(&spec, &trace).unwrap();
let rob = QuantitativeSemantics::eval::<Linear>(&spec, &trace).unwrap();
println!("{:#?}", rob);
let expected = Signal::from_iter(vec![
(Duration::from_secs_f64(0.0), 4.0),
@ -643,7 +642,7 @@ mod tests {
]);
let trace = MyTrace { signals };
let rob = QuantitativeSemantics::eval(&spec, &trace).unwrap();
let rob = QuantitativeSemantics::eval::<Constant>(&spec, &trace).unwrap();
let expected = Signal::from_iter(vec![(Duration::from_secs(0), 2), (Duration::from_secs(5), 2)])
.num_cast::<f64>()
@ -674,7 +673,7 @@ mod tests {
]);
let trace = MyTrace { signals };
let rob = QuantitativeSemantics::eval(&spec, &trace).unwrap();
let rob = QuantitativeSemantics::eval::<Constant>(&spec, &trace).unwrap();
let expected = Signal::from_iter(vec![(Duration::from_secs(4), 3), (Duration::from_secs(6), 3)])
.num_cast::<f64>()