feat!(core): Change Signal to be a sumtype

We want to be able to reason about if a signal is empty, constant, or sampled
at compile time without using any trait objects. Moreover, the core Argus
library shouldn't care about how it deals with interfacing with other languages
like Python. Thus, we remove the need for having an `AnySignal` type and what
not.
This commit is contained in:
Anand Balakrishnan 2023-04-14 10:53:38 -07:00
parent a6a3805107
commit 4431b79bcd
No known key found for this signature in database
10 changed files with 442 additions and 966 deletions

View file

@ -1,8 +1,8 @@
use num_traits::{Num, NumCast, Signed};
use num_traits::{NumCast, Signed};
use super::traits::{BaseSignal, LinearInterpolatable, SignalAbs};
use crate::signals::utils::{apply1, apply2, apply2_const, sync_with_intersection};
use crate::signals::{ConstantSignal, Signal};
use super::traits::{LinearInterpolatable, SignalAbs};
use crate::signals::utils::{apply1, apply2};
use crate::signals::Signal;
impl<T> core::ops::Neg for &Signal<T>
where
@ -16,21 +16,9 @@ where
}
}
impl<T> core::ops::Neg for &ConstantSignal<T>
where
T: Signed + Copy,
{
type Output = ConstantSignal<T>;
/// Negate the signal at each time point
fn neg(self) -> Self::Output {
ConstantSignal::new(self.value.neg())
}
}
impl<T> core::ops::Add for &Signal<T>
where
T: core::ops::Add<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
T: core::ops::Add<T, Output = T> + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
@ -40,45 +28,9 @@ where
}
}
impl<T> core::ops::Add for &ConstantSignal<T>
where
T: core::ops::Add<T, Output = T> + Num + Copy,
{
type Output = ConstantSignal<T>;
/// Add the given signal with another
fn add(self, rhs: Self) -> Self::Output {
ConstantSignal::<T>::new(self.value + rhs.value)
}
}
impl<T> core::ops::Add<&ConstantSignal<T>> for &Signal<T>
where
T: core::ops::Add<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
/// Add the given signal with another
fn add(self, rhs: &ConstantSignal<T>) -> Self::Output {
apply2_const(self, rhs, |lhs, rhs| lhs + rhs)
}
}
impl<T> core::ops::Add<&Signal<T>> for &ConstantSignal<T>
where
T: core::ops::Add<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
/// Add the given signal with another
fn add(self, rhs: &Signal<T>) -> Self::Output {
rhs + self
}
}
impl<T> core::ops::Mul for &Signal<T>
where
T: core::ops::Mul<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
T: core::ops::Mul<T, Output = T> + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
@ -88,46 +40,9 @@ where
}
}
impl<T> core::ops::Mul for &ConstantSignal<T>
where
T: core::ops::Mul<T, Output = T> + Num + Copy,
{
type Output = ConstantSignal<T>;
/// Multiply the given signal with another
fn mul(self, rhs: Self) -> Self::Output {
ConstantSignal::<T>::new(self.value * rhs.value)
}
}
impl<T> core::ops::Mul<&ConstantSignal<T>> for &Signal<T>
where
T: core::ops::Mul<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
/// Multiply the given signal with another
fn mul(self, rhs: &ConstantSignal<T>) -> Self::Output {
apply2_const(self, rhs, |lhs, rhs| lhs * rhs)
}
}
impl<T> core::ops::Mul<&Signal<T>> for &ConstantSignal<T>
where
T: core::ops::Mul<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
/// Multiply the given signal with another
fn mul(self, rhs: &Signal<T>) -> Self::Output {
rhs * self
}
}
impl<T> core::ops::Sub for &Signal<T>
where
T: core::ops::Sub<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable + PartialOrd,
Signal<T>: BaseSignal<Value = T>,
T: core::ops::Sub<T, Output = T> + Copy + LinearInterpolatable + PartialOrd + NumCast,
{
type Output = Signal<T>;
@ -145,73 +60,7 @@ where
}
// the union of the sample points in self and other
let sync_points = sync_with_intersection(self, rhs).unwrap();
sync_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, Linear).unwrap();
let rhs = rhs.interpolate_at(t, Linear).unwrap();
(t, lhs - rhs)
})
.collect()
}
}
impl<T> core::ops::Sub for &ConstantSignal<T>
where
T: core::ops::Sub<T, Output = T> + Num + Copy,
{
type Output = ConstantSignal<T>;
/// Subtract the given signal with another
fn sub(self, rhs: Self) -> Self::Output {
ConstantSignal::<T>::new(self.value - rhs.value)
}
}
impl<T> core::ops::Sub<&ConstantSignal<T>> for &Signal<T>
where
T: core::ops::Sub<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable + PartialOrd,
Signal<T>: BaseSignal<Value = T>,
ConstantSignal<T>: BaseSignal<Value = T>,
{
type Output = Signal<T>;
/// Subtract the given signal with another
fn sub(self, rhs: &ConstantSignal<T>) -> Self::Output {
use super::InterpolationMethod::Linear;
// This has to be manually implemented and cannot use the apply2 functions.
// This is because if we have two signals that cross each other, then there is
// an intermediate point where the two signals are equal. This point must be
// added to the signal appropriately.
// the union of the sample points in self and other
let sync_points = sync_with_intersection(self, rhs).unwrap();
sync_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, Linear).unwrap();
let rhs = rhs.interpolate_at(t, Linear).unwrap();
(t, lhs - rhs)
})
.collect()
}
}
impl<T> core::ops::Sub<&Signal<T>> for &ConstantSignal<T>
where
T: core::ops::Sub<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable + PartialOrd,
{
type Output = Signal<T>;
/// Subtract the given signal with another
fn sub(self, rhs: &Signal<T>) -> Self::Output {
use super::InterpolationMethod::Linear;
// This has to be manually implemented and cannot use the apply2 functions.
// This is because if we have two signals that cross each other, then there is
// an intermediate point where the two signals are equal. This point must be
// added to the signal appropriately.
// the union of the sample points in self and other
let sync_points = sync_with_intersection(self, rhs).unwrap();
let sync_points = self.sync_with_intersection(rhs).unwrap();
sync_points
.into_iter()
.map(|t| {
@ -225,7 +74,7 @@ where
impl<T> core::ops::Div for &Signal<T>
where
T: core::ops::Div<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
T: core::ops::Div<T, Output = T> + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
@ -235,45 +84,9 @@ where
}
}
impl<T> core::ops::Div for &ConstantSignal<T>
where
T: core::ops::Div<T, Output = T> + Num + Copy,
{
type Output = ConstantSignal<T>;
/// Divide the given signal with another
fn div(self, rhs: Self) -> Self::Output {
ConstantSignal::<T>::new(self.value / rhs.value)
}
}
impl<T> core::ops::Div<&ConstantSignal<T>> for &Signal<T>
where
T: core::ops::Div<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
/// Divide the given signal with another
fn div(self, rhs: &ConstantSignal<T>) -> Self::Output {
apply2_const(self, rhs, |lhs, rhs| lhs / rhs)
}
}
impl<T> core::ops::Div<&Signal<T>> for &ConstantSignal<T>
where
T: core::ops::Div<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
/// Divide the given signal with another
fn div(self, rhs: &Signal<T>) -> Self::Output {
apply2_const(rhs, self, |rhs, lhs| lhs / rhs)
}
}
impl<T> num_traits::Pow<Self> for &Signal<T>
where
T: num_traits::Pow<T, Output = T> + Num + NumCast + Copy + LinearInterpolatable,
T: num_traits::Pow<T, Output = T> + Copy + LinearInterpolatable,
{
type Output = Signal<T>;
@ -284,17 +97,6 @@ where
}
macro_rules! signal_abs_impl {
(const $( $ty:ty ), *) => {
$(
impl SignalAbs for ConstantSignal<$ty> {
/// Return the absolute value for the signal
fn abs(&self) -> ConstantSignal<$ty> {
ConstantSignal::new(self.value.abs())
}
}
)*
};
($( $ty:ty ), *) => {
$(
impl SignalAbs for Signal<$ty> {
@ -315,12 +117,3 @@ impl SignalAbs for Signal<u64> {
apply1(self, |v| v)
}
}
signal_abs_impl!(const i64, f32, f64);
impl SignalAbs for ConstantSignal<u64> {
/// Return the absolute value for the signal
fn abs(&self) -> ConstantSignal<u64> {
ConstantSignal::new(self.value)
}
}