fix(argus): explicitly derive AnyExpr

This commit is contained in:
Anand Balakrishnan 2023-10-15 11:14:44 -07:00
parent ecef2b266d
commit 0c69d52379
No known key found for this signature in database
2 changed files with 149 additions and 6 deletions

View file

@ -36,8 +36,9 @@ pub trait IsNumExpr: AnyExpr + Into<NumExpr> {}
pub trait IsBoolExpr: AnyExpr + Into<BoolExpr> {} pub trait IsBoolExpr: AnyExpr + Into<BoolExpr> {}
/// All expressions that are numeric /// All expressions that are numeric
#[derive(Clone, Debug, PartialEq, argus_derive::NumExpr, derive_more::Display)] #[derive(
#[enum_dispatch(AnyExpr)] Clone, Debug, PartialEq, argus_derive::NumExpr, derive_more::Display, derive_more::From, derive_more::TryInto,
)]
pub enum NumExpr { pub enum NumExpr {
/// A signed integer literal /// A signed integer literal
IntLit(IntLit), IntLit(IntLit),
@ -72,9 +73,35 @@ impl NumExpr {
} }
} }
impl AnyExpr for NumExpr {
fn is_numeric(&self) -> bool {
true
}
fn is_boolean(&self) -> bool {
false
}
fn args(&self) -> Vec<ExprRef<'_>> {
match self {
NumExpr::IntLit(expr) => expr.args(),
NumExpr::UIntLit(expr) => expr.args(),
NumExpr::FloatLit(expr) => expr.args(),
NumExpr::IntVar(expr) => expr.args(),
NumExpr::UIntVar(expr) => expr.args(),
NumExpr::FloatVar(expr) => expr.args(),
NumExpr::Neg(expr) => expr.args(),
NumExpr::Add(expr) => expr.args(),
NumExpr::Sub(expr) => expr.args(),
NumExpr::Mul(expr) => expr.args(),
NumExpr::Div(expr) => expr.args(),
NumExpr::Abs(expr) => expr.args(),
}
}
}
/// All expressions that are evaluated to be of type `bool` /// All expressions that are evaluated to be of type `bool`
#[derive(Clone, Debug, PartialEq, argus_derive::BoolExpr, derive_more::Display)] #[derive(
#[enum_dispatch(AnyExpr)] Clone, Debug, PartialEq, argus_derive::BoolExpr, derive_more::Display, derive_more::From, derive_more::TryInto,
)]
pub enum BoolExpr { pub enum BoolExpr {
/// A `bool` literal /// A `bool` literal
BoolLit(BoolLit), BoolLit(BoolLit),
@ -129,6 +156,30 @@ impl BoolExpr {
} }
} }
impl AnyExpr for BoolExpr {
fn is_boolean(&self) -> bool {
true
}
fn is_numeric(&self) -> bool {
false
}
fn args(&self) -> Vec<ExprRef<'_>> {
match self {
BoolExpr::BoolLit(expr) => expr.args(),
BoolExpr::BoolVar(expr) => expr.args(),
BoolExpr::Cmp(expr) => expr.args(),
BoolExpr::Not(expr) => expr.args(),
BoolExpr::And(expr) => expr.args(),
BoolExpr::Or(expr) => expr.args(),
BoolExpr::Next(expr) => expr.args(),
BoolExpr::Oracle(expr) => expr.args(),
BoolExpr::Always(expr) => expr.args(),
BoolExpr::Eventually(expr) => expr.args(),
BoolExpr::Until(expr) => expr.args(),
}
}
}
/// A reference to an expression (either [`BoolExpr`] or [`NumExpr`]). /// A reference to an expression (either [`BoolExpr`] or [`NumExpr`]).
#[derive(Clone, Copy, Debug, derive_more::From)] #[derive(Clone, Copy, Debug, derive_more::From)]
pub enum ExprRef<'a> { pub enum ExprRef<'a> {
@ -139,8 +190,7 @@ pub enum ExprRef<'a> {
} }
/// An expression (either [`BoolExpr`] or [`NumExpr`]) /// An expression (either [`BoolExpr`] or [`NumExpr`])
#[derive(Clone, Debug, derive_more::Display)] #[derive(Clone, Debug, derive_more::Display, derive_more::From, derive_more::TryInto)]
#[enum_dispatch(AnyExpr)]
pub enum Expr { pub enum Expr {
/// A reference to a [`BoolExpr`] /// A reference to a [`BoolExpr`]
Bool(BoolExpr), Bool(BoolExpr),
@ -148,6 +198,23 @@ pub enum Expr {
Num(NumExpr), Num(NumExpr),
} }
impl AnyExpr for Expr {
fn is_numeric(&self) -> bool {
matches!(self, Expr::Num(_))
}
fn is_boolean(&self) -> bool {
matches!(self, Expr::Bool(_))
}
fn args(&self) -> Vec<ExprRef<'_>> {
match self {
Expr::Bool(expr) => expr.args(),
Expr::Num(expr) => expr.args(),
}
}
}
/// Expression builder /// Expression builder
/// ///
/// The `ExprBuilder` is a factory structure that deals with the creation of /// The `ExprBuilder` is a factory structure that deals with the creation of
@ -636,4 +703,26 @@ mod tests {
test_bool_binop!(And, bitand with &); test_bool_binop!(And, bitand with &);
test_bool_binop!(Or, bitor with |); test_bool_binop!(Or, bitor with |);
#[test]
fn is_numeric() {
let mut builder = ExprBuilder::new();
let a = builder.int_const(10);
let b = builder.int_var("b".to_owned()).unwrap();
let spec = a + b;
assert!(spec.is_numeric());
assert!(!spec.is_boolean());
}
#[test]
fn is_boolean() {
let mut builder = ExprBuilder::new();
let a = builder.bool_const(true);
let b = builder.bool_var("b".to_owned()).unwrap();
let spec = a & b;
assert!(!spec.is_numeric());
assert!(spec.is_boolean());
}
} }

View file

@ -3,6 +3,7 @@
use std::time::Duration; use std::time::Duration;
use crate::core::expr::ExprBuilder; use crate::core::expr::ExprBuilder;
use crate::core::AnyExpr;
mod lexer; mod lexer;
mod syntax; mod syntax;
@ -137,6 +138,7 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
match op { match op {
syntax::UnaryOps::Neg => { syntax::UnaryOps::Neg => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(arg.is_numeric(), "expected numeric expression, got {:?}", arg);
let crate::core::expr::Expr::Num(arg) = arg else { let crate::core::expr::Expr::Num(arg) = arg else {
unreachable!("- must have numeric expression argument"); unreachable!("- must have numeric expression argument");
}; };
@ -144,6 +146,7 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::UnaryOps::Not => { syntax::UnaryOps::Not => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(arg.is_boolean(), "expected boolean expression, got {:?}", arg);
let crate::core::expr::Expr::Bool(arg) = arg else { let crate::core::expr::Expr::Bool(arg) = arg else {
unreachable!("`Not` must have boolean expression argument"); unreachable!("`Not` must have boolean expression argument");
}; };
@ -151,6 +154,7 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::UnaryOps::Next => { syntax::UnaryOps::Next => {
use core::ops::Bound; use core::ops::Bound;
assert!(arg.is_boolean(), "expected boolean expression, got {:?}", arg);
let crate::core::expr::Expr::Bool(arg) = arg else { let crate::core::expr::Expr::Bool(arg) = arg else {
unreachable!("`Next` must have boolean expression argument"); unreachable!("`Next` must have boolean expression argument");
}; };
@ -171,6 +175,7 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
} }
syntax::UnaryOps::Always => { syntax::UnaryOps::Always => {
assert!(arg.is_boolean(), "expected boolean expression, got {:?}", arg);
let crate::core::expr::Expr::Bool(arg) = arg else { let crate::core::expr::Expr::Bool(arg) = arg else {
unreachable!("`Always` must have boolean expression argument"); unreachable!("`Always` must have boolean expression argument");
}; };
@ -180,6 +185,7 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
} }
syntax::UnaryOps::Eventually => { syntax::UnaryOps::Eventually => {
assert!(arg.is_boolean(), "expected boolean expression, got {:?}", arg);
let crate::core::expr::Expr::Bool(arg) = arg else { let crate::core::expr::Expr::Bool(arg) = arg else {
unreachable!("`Eventually` must have boolean expression argument"); unreachable!("`Eventually` must have boolean expression argument");
}; };
@ -202,6 +208,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
match op { match op {
syntax::BinaryOps::Add => { syntax::BinaryOps::Add => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("`Add` must have numeric expression arguments"); unreachable!("`Add` must have numeric expression arguments");
}; };
@ -211,6 +220,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Sub => { syntax::BinaryOps::Sub => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("`Sub` must have numeric expression arguments"); unreachable!("`Sub` must have numeric expression arguments");
}; };
@ -218,6 +230,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Mul => { syntax::BinaryOps::Mul => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("`Mul` must have numeric expression arguments"); unreachable!("`Mul` must have numeric expression arguments");
}; };
@ -227,6 +242,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Div => { syntax::BinaryOps::Div => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("`Div` must have numeric expression arguments"); unreachable!("`Div` must have numeric expression arguments");
}; };
@ -234,6 +252,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Lt => { syntax::BinaryOps::Lt => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("Relational operation must have numeric expression arguments"); unreachable!("Relational operation must have numeric expression arguments");
}; };
@ -241,6 +262,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Le => { syntax::BinaryOps::Le => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("Relational operation must have numeric expression arguments"); unreachable!("Relational operation must have numeric expression arguments");
}; };
@ -248,6 +272,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Gt => { syntax::BinaryOps::Gt => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("Relational operation must have numeric expression arguments"); unreachable!("Relational operation must have numeric expression arguments");
}; };
@ -255,6 +282,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Ge => { syntax::BinaryOps::Ge => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("Relational operation must have numeric expression arguments"); unreachable!("Relational operation must have numeric expression arguments");
}; };
@ -262,6 +292,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Eq => { syntax::BinaryOps::Eq => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got: {}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got: {}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("Relational operation must have numeric expression arguments"); unreachable!("Relational operation must have numeric expression arguments");
}; };
@ -269,6 +302,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Neq => { syntax::BinaryOps::Neq => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_numeric(), "expected numeric expression, got {:?}", lhs);
assert!(rhs.is_numeric(), "expected numeric expression, got {:?}", rhs);
let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Num(lhs), crate::core::expr::Expr::Num(rhs)) = (lhs, rhs) else {
unreachable!("Relational operation must have numeric expression arguments"); unreachable!("Relational operation must have numeric expression arguments");
}; };
@ -276,6 +312,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::And => { syntax::BinaryOps::And => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_boolean(), "expected boolean expression, got {:?}", lhs);
assert!(rhs.is_boolean(), "expected boolean expression, got {:?}", rhs);
let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else {
unreachable!("`And` must have boolean expression arguments"); unreachable!("`And` must have boolean expression arguments");
}; };
@ -285,6 +324,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Or => { syntax::BinaryOps::Or => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_boolean(), "expected boolean expression, got {:?}", lhs);
assert!(rhs.is_boolean(), "expected boolean expression, got {:?}", rhs);
let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else {
unreachable!("`Or` must have boolean expression arguments"); unreachable!("`Or` must have boolean expression arguments");
}; };
@ -294,6 +336,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Implies => { syntax::BinaryOps::Implies => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_boolean(), "expected boolean expression, got {:?}", lhs);
assert!(rhs.is_boolean(), "expected boolean expression, got {:?}", rhs);
let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else {
unreachable!("`Implies` must have boolean expression arguments"); unreachable!("`Implies` must have boolean expression arguments");
}; };
@ -303,6 +348,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Equiv => { syntax::BinaryOps::Equiv => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_boolean(), "expected boolean expression, got {:?}", lhs);
assert!(rhs.is_boolean(), "expected boolean expression, got {:?}", rhs);
let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else {
unreachable!("`Equiv` must have boolean expression arguments"); unreachable!("`Equiv` must have boolean expression arguments");
}; };
@ -312,6 +360,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
} }
syntax::BinaryOps::Xor => { syntax::BinaryOps::Xor => {
assert!(interval.is_none()); assert!(interval.is_none());
assert!(lhs.is_boolean(), "expected boolean expression, got {:?}", lhs);
assert!(rhs.is_boolean(), "expected boolean expression, got {:?}", rhs);
let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else {
unreachable!("`Xor` must have boolean expression arguments"); unreachable!("`Xor` must have boolean expression arguments");
}; };
@ -320,6 +371,9 @@ fn ast_to_expr<'tokens, 'src: 'tokens>(
.map_err(|err| Rich::custom(span, err.to_string())) .map_err(|err| Rich::custom(span, err.to_string()))
} }
syntax::BinaryOps::Until => { syntax::BinaryOps::Until => {
assert!(lhs.is_boolean(), "expected boolean expression, got {:?}", lhs);
assert!(rhs.is_boolean(), "expected boolean expression, got {:?}", rhs);
let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else { let (crate::core::expr::Expr::Bool(lhs), crate::core::expr::Expr::Bool(rhs)) = (lhs, rhs) else {
unreachable!("`Until` must have boolean expression arguments"); unreachable!("`Until` must have boolean expression arguments");
}; };