feat(core): Add SignalAbs trait for numeric signals

This commit is contained in:
Anand Balakrishnan 2023-04-04 11:55:48 -07:00
parent 55b7cdd075
commit 6e41380262
No known key found for this signature in database
5 changed files with 72 additions and 23 deletions

View file

@ -2,7 +2,7 @@ use itertools::Itertools;
use num_traits::{Num, NumCast}; use num_traits::{Num, NumCast};
use crate::signals::traits::SignalNumCast; use crate::signals::traits::SignalNumCast;
use crate::signals::{AnySignal, ConstantSignal, Signal}; use crate::signals::{ConstantSignal, Signal};
macro_rules! impl_cast { macro_rules! impl_cast {
($type:ty) => { ($type:ty) => {

View file

@ -2,7 +2,7 @@ use std::cmp::Ordering;
use num_traits::NumCast; use num_traits::NumCast;
use super::traits::{BaseSignal, LinearInterpolatable, SignalMinMax, SignalPartialOrd}; use super::traits::{BaseSignal, LinearInterpolatable, SignalMinMax, SignalPartialOrd, SignalSyncPoints};
use super::utils::sync_with_intersection; use super::utils::sync_with_intersection;
use super::{ConstantSignal, InterpolationMethod, Signal}; use super::{ConstantSignal, InterpolationMethod, Signal};
@ -109,13 +109,15 @@ where
} }
} }
impl<T> SignalMinMax for Signal<T> impl<T, Lhs, Rhs> SignalMinMax<Rhs> for Lhs
where where
T: PartialOrd + Copy + num_traits::NumCast + LinearInterpolatable, T: PartialOrd + Copy + num_traits::NumCast + LinearInterpolatable,
Lhs: SignalSyncPoints<Rhs> + BaseSignal<Value = T>,
Rhs: SignalSyncPoints<Self> + BaseSignal<Value = T>,
{ {
type Output = Signal<T>; type Output = Signal<T>;
fn min(&self, other: &Self) -> Self::Output { fn min(&self, other: &Rhs) -> Self::Output {
let time_points = sync_with_intersection(self, other).unwrap(); let time_points = sync_with_intersection(self, other).unwrap();
time_points time_points
.into_iter() .into_iter()
@ -131,7 +133,7 @@ where
.collect() .collect()
} }
fn max(&self, other: &Self) -> Self::Output { fn max(&self, other: &Rhs) -> Self::Output {
let time_points = sync_with_intersection(self, other).unwrap(); let time_points = sync_with_intersection(self, other).unwrap();
time_points time_points
.into_iter() .into_iter()

View file

@ -1,6 +1,6 @@
use num_traits::{Num, NumCast, Signed}; use num_traits::{Num, NumCast, Signed};
use super::traits::{BaseSignal, LinearInterpolatable}; use super::traits::{BaseSignal, LinearInterpolatable, SignalAbs};
use crate::signals::utils::{apply1, apply2, apply2_const, sync_with_intersection}; use crate::signals::utils::{apply1, apply2, apply2_const, sync_with_intersection};
use crate::signals::{ConstantSignal, Signal}; use crate::signals::{ConstantSignal, Signal};
@ -282,3 +282,45 @@ where
apply2(self, other, |lhs, rhs| lhs.pow(rhs)) apply2(self, other, |lhs, rhs| lhs.pow(rhs))
} }
} }
macro_rules! signal_abs_impl {
(const $( $ty:ty ), *) => {
$(
impl SignalAbs for ConstantSignal<$ty> {
/// Return the absolute value for the signal
fn abs(&self) -> ConstantSignal<$ty> {
ConstantSignal::new(self.value.abs())
}
}
)*
};
($( $ty:ty ), *) => {
$(
impl SignalAbs for Signal<$ty> {
/// Return the absolute value for each sample in the signal
fn abs(&self) -> Signal<$ty> {
apply1(self, |v| v.abs())
}
}
)*
};
}
signal_abs_impl!(i64, f32, f64);
impl SignalAbs for Signal<u64> {
/// Return the absolute value for each sample in the signal
fn abs(&self) -> Signal<u64> {
apply1(self, |v| v)
}
}
signal_abs_impl!(const i64, f32, f64);
impl SignalAbs for ConstantSignal<u64> {
/// Return the absolute value for the signal
fn abs(&self) -> ConstantSignal<u64> {
ConstantSignal::new(self.value)
}
}

View file

@ -266,20 +266,20 @@ where
} }
} }
impl<T> SignalSyncPoints<ConstantSignal<T>> for ConstantSignal<T> // impl<T> SignalSyncPoints<ConstantSignal<T>> for ConstantSignal<T>
where // where
T: Copy, // T: Copy,
Self: BaseSignal<Value = T>, // Self: BaseSignal<Value = T>,
{ // {
type Output<'a> = Empty<&'a Duration> // type Output<'a> = Empty<&'a Duration>
where // where
Self: 'a, // Self: 'a,
Self: 'a; // Self: 'a;
//
fn synchronization_points<'a>(&'a self, _other: &'a ConstantSignal<T>) -> Option<Self::Output<'a>> { // fn synchronization_points<'a>(&'a self, _other: &'a ConstantSignal<T>) ->
Some(core::iter::empty()) // Option<Self::Output<'a>> { Some(core::iter::empty())
} // }
} // }
impl<T> SignalSyncPoints<Signal<T>> for ConstantSignal<T> impl<T> SignalSyncPoints<Signal<T>> for ConstantSignal<T>
where where
@ -356,3 +356,8 @@ pub trait SignalNumCast {
fn to_f32(&self) -> Option<Self::Output<f32>>; fn to_f32(&self) -> Option<Self::Output<f32>>;
fn to_f64(&self) -> Option<Self::Output<f64>>; fn to_f64(&self) -> Option<Self::Output<f64>>;
} }
/// Trait for computing the absolute value of the samples in a signal
pub trait SignalAbs {
fn abs(&self) -> Self;
}

View file

@ -109,11 +109,11 @@ where
Some(return_points) Some(return_points)
} }
pub fn apply1<T, F>(signal: &Signal<T>, op: F) -> Signal<T> pub fn apply1<T, U, F>(signal: &Signal<T>, op: F) -> Signal<U>
where where
T: Copy, T: Copy,
F: Fn(T) -> T, F: Fn(T) -> U,
Signal<T>: std::iter::FromIterator<(Duration, T)>, Signal<U>: std::iter::FromIterator<(Duration, U)>,
{ {
signal.iter().map(|(t, v)| (*t, op(*v))).collect() signal.iter().map(|(t, v)| (*t, op(*v))).collect()
} }