refactor~(core): use traits and structs for interpolation

We have to now pass the interpolation method as a generic argument to methods.
This commit is contained in:
Anand Balakrishnan 2023-06-07 09:57:56 -04:00
parent 2b16ef9c40
commit 87afc11b90
No known key found for this signature in database
8 changed files with 314 additions and 306 deletions

View file

@ -1,6 +1,8 @@
use num_traits::{NumCast, Signed};
use num_traits::Signed;
use super::traits::{LinearInterpolatable, SignalAbs};
use super::interpolation::Linear;
use super::traits::SignalAbs;
use super::{FindIntersectionMethod, InterpolationMethod};
use crate::signals::utils::{apply1, apply2};
use crate::signals::Signal;
@ -18,37 +20,39 @@ where
impl<T> core::ops::Add for &Signal<T>
where
T: core::ops::Add<T, Output = T> + Copy + LinearInterpolatable,
T: core::ops::Add<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Add the given signal with another
fn add(self, rhs: Self) -> Self::Output {
apply2(self, rhs, |lhs, rhs| lhs + rhs)
apply2::<_, _, _, Linear>(self, rhs, |lhs, rhs| lhs + rhs)
}
}
impl<T> core::ops::Mul for &Signal<T>
where
T: core::ops::Mul<T, Output = T> + Copy + LinearInterpolatable,
T: core::ops::Mul<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Multiply the given signal with another
fn mul(self, rhs: Self) -> Self::Output {
apply2(self, rhs, |lhs, rhs| lhs * rhs)
apply2::<_, _, _, Linear>(self, rhs, |lhs, rhs| lhs * rhs)
}
}
impl<T> core::ops::Sub for &Signal<T>
where
T: core::ops::Sub<T, Output = T> + Copy + LinearInterpolatable + PartialOrd + NumCast,
T: core::ops::Sub<T, Output = T> + Copy + PartialOrd,
Linear: InterpolationMethod<T> + FindIntersectionMethod<T>,
{
type Output = Signal<T>;
/// Subtract the given signal with another
fn sub(self, rhs: Self) -> Self::Output {
use super::InterpolationMethod::Linear;
// This has to be manually implemented and cannot use the apply2 functions.
// This is because if we have two signals that cross each other, then there is
// an intermediate point where the two signals are equal. This point must be
@ -60,12 +64,12 @@ where
}
// the union of the sample points in self and other
let sync_points = self.sync_with_intersection(rhs).unwrap();
let sync_points = self.sync_with_intersection::<Linear>(rhs).unwrap();
sync_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, Linear).unwrap();
let rhs = rhs.interpolate_at(t, Linear).unwrap();
let lhs = self.interpolate_at::<Linear>(t).unwrap();
let rhs = rhs.interpolate_at::<Linear>(t).unwrap();
(t, lhs - rhs)
})
.collect()
@ -74,25 +78,27 @@ where
impl<T> core::ops::Div for &Signal<T>
where
T: core::ops::Div<T, Output = T> + Copy + LinearInterpolatable,
T: core::ops::Div<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Divide the given signal with another
fn div(self, rhs: Self) -> Self::Output {
apply2(self, rhs, |lhs, rhs| lhs / rhs)
apply2::<_, _, _, Linear>(self, rhs, |lhs, rhs| lhs / rhs)
}
}
impl<T> num_traits::Pow<Self> for &Signal<T>
where
T: num_traits::Pow<T, Output = T> + Copy + LinearInterpolatable,
T: num_traits::Pow<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Returns the values in `self` to the power of the values in `other`
fn pow(self, other: Self) -> Self::Output {
apply2(self, other, |lhs, rhs| lhs.pow(rhs))
apply2::<_, _, _, Linear>(self, other, |lhs, rhs| lhs.pow(rhs))
}
}