feat: add Sub and Abs expression nodes

This commit is contained in:
Anand Balakrishnan 2023-04-04 14:32:37 -07:00
parent 4dc6effbde
commit 2b447409a1
No known key found for this signature in database
2 changed files with 44 additions and 22 deletions

View file

@ -26,8 +26,11 @@ pub enum NumExpr {
Neg { arg: Box<NumExpr> }, Neg { arg: Box<NumExpr> },
Add { args: Vec<NumExpr> }, Add { args: Vec<NumExpr> },
Sub { lhs: Box<NumExpr>, rhs: Box<NumExpr> },
Mul { args: Vec<NumExpr> }, Mul { args: Vec<NumExpr> },
Div { dividend: Box<NumExpr>, divisor: Box<NumExpr> }, Div { dividend: Box<NumExpr>, divisor: Box<NumExpr> },
Abs { arg: Box<NumExpr> },
} }
impl Expr for NumExpr { impl Expr for NumExpr {

View file

@ -5,20 +5,23 @@ use crate::Trace;
macro_rules! signal_num_op_impl { macro_rules! signal_num_op_impl {
// Unary numeric opeartions // Unary numeric opeartions
(- $signal:ident) => {{ ($op:ident, $signal:ident, [$( $type:ident ),*]) => {
paste::paste! {
{
use argus_core::prelude::*; use argus_core::prelude::*;
use AnySignal::*; use AnySignal::*;
match $signal { match $signal {
Bool(_) | ConstBool(_) => panic!("cannot perform unary operation (-) on Boolean signals"), $(
Int(signal) => AnySignal::from(-(&signal)), [< $type >](signal) => AnySignal::from(signal.$op()),
ConstInt(signal) => AnySignal::from(-(&signal)), [<Const $type >](signal) => AnySignal::from(signal.$op()),
UInt(_) | ConstUInt(_) => panic!("cannot perform unary operation (-) on unsigned integer signals"), )*
Float(signal) => AnySignal::from(-(&signal)), _ => panic!("cannot perform unary operation ({})", stringify!($op)),
ConstFloat(signal) => AnySignal::from(-(&signal)),
} }
}}; }
}
};
($lhs:ident $op:tt $rhs:ident, [$( $type:ident ),*]) => { ($op:ident, $lhs:ident, $rhs:ident, [$( $type:ident ),*]) => {
paste::paste!{ paste::paste!{
{ {
use argus_core::prelude::*; use argus_core::prelude::*;
@ -26,10 +29,10 @@ macro_rules! signal_num_op_impl {
match ($lhs, $rhs) { match ($lhs, $rhs) {
(Bool(_), _) | (ConstBool(_), _) | (_, Bool(_)) | (_, ConstBool(_)) => panic!("cannot perform numeric operation {} for boolean arguments", stringify!($op)), (Bool(_), _) | (ConstBool(_), _) | (_, Bool(_)) | (_, ConstBool(_)) => panic!("cannot perform numeric operation {} for boolean arguments", stringify!($op)),
$( $(
([<$type >](lhs), [< $type >](rhs)) => AnySignal::from(&lhs $op &rhs), ([<$type >](lhs), [< $type >](rhs)) => AnySignal::from(lhs.$op(&rhs)),
([<$type >](lhs), [< Const $type >](rhs)) => AnySignal::from(&lhs $op &rhs), ([<$type >](lhs), [< Const $type >](rhs)) => AnySignal::from(lhs.$op(&rhs)),
([<Const $type >](lhs), [< $type >](rhs)) => AnySignal::from(&lhs $op &rhs), ([<Const $type >](lhs), [< $type >](rhs)) => AnySignal::from(lhs.$op(&rhs)),
([<Const $type >](lhs), [< Const $type >](rhs)) => AnySignal::from(&lhs $op &rhs), ([<Const $type >](lhs), [< Const $type >](rhs)) => AnySignal::from(lhs.$op(&rhs)),
)* )*
_ => panic!("mismatched argument types for {} operation", stringify!($op)), _ => panic!("mismatched argument types for {} operation", stringify!($op)),
} }
@ -38,9 +41,9 @@ macro_rules! signal_num_op_impl {
}; };
// Binary numeric opeartions // Binary numeric opeartions
($lhs:ident $op:tt $rhs:ident) => { ($op:ident, $lhs:ident, $rhs:ident) => {
signal_num_op_impl!( signal_num_op_impl!(
$lhs $op $rhs, $op, $lhs, $rhs,
[Int, UInt, Float] [Int, UInt, Float]
) )
}; };
@ -53,6 +56,9 @@ pub struct NumExprEval;
impl NumExprEval { impl NumExprEval {
pub fn eval(root: &NumExpr, trace: &impl Trace) -> AnySignal { pub fn eval(root: &NumExpr, trace: &impl Trace) -> AnySignal {
use core::ops::{Add, Div, Mul, Neg, Sub};
use argus_core::signals::traits::SignalAbs;
match root { match root {
NumExpr::IntLit(val) => ConstantSignal::new(*val).into(), NumExpr::IntLit(val) => ConstantSignal::new(*val).into(),
NumExpr::UIntLit(val) => ConstantSignal::new(*val).into(), NumExpr::UIntLit(val) => ConstantSignal::new(*val).into(),
@ -63,20 +69,33 @@ impl NumExprEval {
} }
NumExpr::Neg { arg } => { NumExpr::Neg { arg } => {
let arg_sig = Self::eval(arg, trace); let arg_sig = Self::eval(arg, trace);
signal_num_op_impl!(-arg_sig) signal_num_op_impl!(neg, arg_sig, [Int, Float])
} }
NumExpr::Add { args } => { NumExpr::Add { args } => {
let args_signals = args.iter().map(|arg| Self::eval(arg, trace)); let args_signals = args.iter().map(|arg| Self::eval(arg, trace));
args_signals.reduce(|acc, arg| signal_num_op_impl!(acc + arg)).unwrap() args_signals
.reduce(|acc, arg| signal_num_op_impl!(add, acc, arg))
.unwrap()
}
NumExpr::Sub { lhs, rhs } => {
let lhs = Self::eval(lhs, trace);
let rhs = Self::eval(rhs, trace);
signal_num_op_impl!(sub, lhs, rhs)
} }
NumExpr::Mul { args } => { NumExpr::Mul { args } => {
let args_signals = args.iter().map(|arg| Self::eval(arg, trace)); let args_signals = args.iter().map(|arg| Self::eval(arg, trace));
args_signals.reduce(|acc, arg| signal_num_op_impl!(acc * arg)).unwrap() args_signals
.reduce(|acc, arg| signal_num_op_impl!(mul, acc, arg))
.unwrap()
} }
NumExpr::Div { dividend, divisor } => { NumExpr::Div { dividend, divisor } => {
let dividend = Self::eval(dividend, trace); let dividend = Self::eval(dividend, trace);
let divisor = Self::eval(divisor, trace); let divisor = Self::eval(divisor, trace);
signal_num_op_impl!(dividend / divisor) signal_num_op_impl!(div, dividend, divisor)
}
NumExpr::Abs { arg } => {
let arg = Self::eval(arg, trace);
signal_num_op_impl!(abs, arg, [Int, UInt, Float])
} }
} }
} }