refactor~(core): use traits and structs for interpolation

We have to now pass the interpolation method as a generic argument to methods.
This commit is contained in:
Anand Balakrishnan 2023-06-07 09:57:56 -04:00
parent 2b16ef9c40
commit 87afc11b90
No known key found for this signature in database
8 changed files with 314 additions and 306 deletions

View file

@ -1,3 +1,4 @@
use super::interpolation::Linear;
use crate::signals::utils::{apply1, apply2};
use crate::signals::Signal;
@ -13,7 +14,7 @@ impl core::ops::BitAnd<Self> for &Signal<bool> {
type Output = Signal<bool>;
fn bitand(self, other: Self) -> Self::Output {
apply2(self, other, |lhs, rhs| lhs && rhs)
apply2::<_, _, _, Linear>(self, other, |lhs, rhs| lhs && rhs)
}
}
@ -21,6 +22,6 @@ impl core::ops::BitOr<Self> for &Signal<bool> {
type Output = Signal<bool>;
fn bitor(self, other: Self) -> Self::Output {
apply2(self, other, |lhs, rhs| lhs || rhs)
apply2::<_, _, _, Linear>(self, other, |lhs, rhs| lhs || rhs)
}
}

View file

@ -1,30 +1,29 @@
use std::cmp::Ordering;
use num_traits::NumCast;
use super::traits::{LinearInterpolatable, SignalMinMax, SignalPartialOrd};
use super::{InterpolationMethod, Signal};
use super::interpolation::Linear;
use super::traits::{SignalMinMax, SignalPartialOrd};
use super::{FindIntersectionMethod, InterpolationMethod, Signal};
impl<T> SignalPartialOrd<Self> for Signal<T>
where
T: PartialOrd + Copy + std::fmt::Debug + NumCast + LinearInterpolatable,
T: PartialOrd + Copy,
Linear: InterpolationMethod<T> + FindIntersectionMethod<T>,
{
fn signal_cmp<F>(&self, other: &Self, op: F) -> Option<Signal<bool>>
where
F: Fn(Ordering) -> bool,
{
use super::InterpolationMethod::Linear;
// This has to be manually implemented and cannot use the apply2 functions.
// This is because if we have two signals that cross each other, then there is
// an intermediate point where the two signals are equal. This point must be
// added to the signal appropriately.
// the union of the sample points in self and other
let sync_points = self.sync_with_intersection(other)?;
let sync_points = self.sync_with_intersection::<Linear>(other)?;
let sig: Option<Signal<bool>> = sync_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, Linear).unwrap();
let rhs = other.interpolate_at(t, Linear).unwrap();
let lhs = self.interpolate_at::<Linear>(t).unwrap();
let rhs = other.interpolate_at::<Linear>(t).unwrap();
let cmp = lhs.partial_cmp(&rhs);
cmp.map(|v| (t, op(v)))
})
@ -35,17 +34,18 @@ where
impl<T> SignalMinMax<Self> for Signal<T>
where
T: PartialOrd + Copy + LinearInterpolatable + NumCast,
T: PartialOrd + Copy,
Linear: InterpolationMethod<T> + FindIntersectionMethod<T>,
{
type Output = Signal<T>;
fn min(&self, other: &Self) -> Self::Output {
let time_points = self.sync_with_intersection(other).unwrap();
let time_points = self.sync_with_intersection::<Linear>(other).unwrap();
time_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, InterpolationMethod::Linear).unwrap();
let rhs = other.interpolate_at(t, InterpolationMethod::Linear).unwrap();
let lhs = self.interpolate_at::<Linear>(t).unwrap();
let rhs = other.interpolate_at::<Linear>(t).unwrap();
if lhs < rhs {
(t, lhs)
} else {
@ -56,12 +56,12 @@ where
}
fn max(&self, other: &Self) -> Self::Output {
let time_points = self.sync_with_intersection(other).unwrap();
let time_points = self.sync_with_intersection::<Linear>(other).unwrap();
time_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, InterpolationMethod::Linear).unwrap();
let rhs = other.interpolate_at(t, InterpolationMethod::Linear).unwrap();
let lhs = self.interpolate_at::<Linear>(t).unwrap();
let rhs = other.interpolate_at::<Linear>(t).unwrap();
if lhs > rhs {
(t, lhs)
} else {

View file

@ -0,0 +1,177 @@
//! Interpolation methods
use std::time::Duration;
use super::utils::Neighborhood;
use super::{FindIntersectionMethod, InterpolationMethod, Sample};
/// Constant interpolation.
///
/// Here, the previous signal value is propagated to the requested time point.
pub struct Constant;
impl<T: Clone> InterpolationMethod<T> for Constant {
fn at(a: &Sample<T>, b: &Sample<T>, time: Duration) -> Option<T> {
if time == b.time {
Some(b.value.clone())
} else if a.time <= time && time < b.time {
Some(a.value.clone())
} else {
None
}
}
}
/// Nearest interpolation.
///
/// Here, the signal value from the nearest sample (time-wise) is propagated to the
/// requested time point.
pub struct Nearest;
impl<T: Clone> InterpolationMethod<T> for Nearest {
fn at(a: &super::Sample<T>, b: &super::Sample<T>, time: std::time::Duration) -> Option<T> {
if time < a.time || time > b.time {
// `time` is outside the segments.
None
} else if (b.time - time) > (time - a.time) {
// a is closer to the required time than b
Some(a.value.clone())
} else {
// b is closer
Some(b.value.clone())
}
}
}
/// Linear interpolation.
///
/// Here, linear interpolation is performed to estimate the sample at the time point
/// between two samples.
pub struct Linear;
impl InterpolationMethod<bool> for Linear {
fn at(a: &Sample<bool>, b: &Sample<bool>, time: Duration) -> Option<bool> {
if a.time < time && time < b.time {
// We can't linear interpolate a boolean, so we return the previous.
Some(a.value)
} else {
None
}
}
}
impl FindIntersectionMethod<bool> for Linear {
fn find_intersection(a: &Neighborhood<bool>, b: &Neighborhood<bool>) -> Sample<bool> {
let Sample { time: ta1, value: ya1 } = a.first.unwrap();
let Sample { time: ta2, value: ya2 } = a.second.unwrap();
let Sample { time: tb1, value: yb1 } = b.first.unwrap();
let Sample { time: tb2, value: yb2 } = b.second.unwrap();
let left_cmp = ya1.cmp(&yb1);
let right_cmp = ya2.cmp(&yb2);
if left_cmp.is_eq() {
// They already intersect, so we return the inner time-point
if ta1 < tb1 {
Sample { time: tb1, value: yb1 }
} else {
Sample { time: ta1, value: ya1 }
}
} else if right_cmp.is_eq() {
// They intersect at the end, so we return the outer time-point, as that is
// when they become equal.
if ta2 < tb2 {
Sample { time: tb2, value: yb2 }
} else {
Sample { time: ta2, value: ya2 }
}
} else {
// The switched, so the one that switched earlier will intersect with the
// other.
// So, we find the one that has a lower time point, i.e., the inner one.
if ta2 < tb2 {
Sample { time: ta2, value: ya2 }
} else {
Sample { time: tb2, value: yb2 }
}
}
}
}
macro_rules! interpolate_for_num {
($ty:ty) => {
impl InterpolationMethod<$ty> for Linear {
fn at(first: &Sample<$ty>, second: &Sample<$ty>, time: Duration) -> Option<$ty> {
use num_traits::cast;
// We will need to cast the samples to f64 values (along with the time
// window) to be able to interpolate correctly.
// TODO(anand): Verify this works.
let t1 = first.time.as_secs_f64();
let t2 = second.time.as_secs_f64();
let at = time.as_secs_f64();
assert!((t1..=t2).contains(&at));
// We need to do stable linear interpolation
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p0811r3.html
let a: f64 = cast(first.value).unwrap();
let b: f64 = cast(second.value).unwrap();
// Set t to a value in [0, 1]
let t = (at - t1) / (t2 - t1);
assert!((0.0..=1.0).contains(&t));
let val = if (a <= 0.0 && b >= 0.0) || (a >= 0.0 && b <= 0.0) {
t * b + (1.0 - t) * a
} else if t == 1.0 {
b
} else {
a + t * (b - a)
};
cast(val)
}
}
impl FindIntersectionMethod<$ty> for Linear {
fn find_intersection(a: &Neighborhood<$ty>, b: &Neighborhood<$ty>) -> Sample<$ty> {
// https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection#Given_two_points_on_each_line
use num_traits::cast;
let Sample { time: t1, value: y1 } = a.first.unwrap();
let Sample { time: t2, value: y2 } = a.second.unwrap();
let Sample { time: t3, value: y3 } = b.first.unwrap();
let Sample { time: t4, value: y4 } = b.second.unwrap();
let t1 = t1.as_secs_f64();
let t2 = t2.as_secs_f64();
let t3 = t3.as_secs_f64();
let t4 = t4.as_secs_f64();
let y1: f64 = cast(y1).unwrap();
let y2: f64 = cast(y2).unwrap();
let y3: f64 = cast(y3).unwrap();
let y4: f64 = cast(y4).unwrap();
let denom = ((t1 - t2) * (y3 - y4)) - ((y1 - y2) * (t3 - t4));
let t_top = (((t1 * y2) - (y1 * t2)) * (t3 - t4)) - ((t1 - t2) * (t3 * y4 - y3 * t4));
let y_top = (((t1 * y2) - (y1 * t2)) * (y3 - y4)) - ((y1 - y2) * (t3 * y4 - y3 * t4));
let t = Duration::from_secs_f64(t_top / denom);
let y: $ty = num_traits::cast(y_top / denom).unwrap();
Sample { time: t, value: y }
}
}
};
}
interpolate_for_num!(i8);
interpolate_for_num!(i16);
interpolate_for_num!(i32);
interpolate_for_num!(i64);
interpolate_for_num!(u8);
interpolate_for_num!(u16);
interpolate_for_num!(u32);
interpolate_for_num!(u64);
interpolate_for_num!(f32);
interpolate_for_num!(f64);

View file

@ -1,6 +1,8 @@
use num_traits::{NumCast, Signed};
use num_traits::Signed;
use super::traits::{LinearInterpolatable, SignalAbs};
use super::interpolation::Linear;
use super::traits::SignalAbs;
use super::{FindIntersectionMethod, InterpolationMethod};
use crate::signals::utils::{apply1, apply2};
use crate::signals::Signal;
@ -18,37 +20,39 @@ where
impl<T> core::ops::Add for &Signal<T>
where
T: core::ops::Add<T, Output = T> + Copy + LinearInterpolatable,
T: core::ops::Add<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Add the given signal with another
fn add(self, rhs: Self) -> Self::Output {
apply2(self, rhs, |lhs, rhs| lhs + rhs)
apply2::<_, _, _, Linear>(self, rhs, |lhs, rhs| lhs + rhs)
}
}
impl<T> core::ops::Mul for &Signal<T>
where
T: core::ops::Mul<T, Output = T> + Copy + LinearInterpolatable,
T: core::ops::Mul<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Multiply the given signal with another
fn mul(self, rhs: Self) -> Self::Output {
apply2(self, rhs, |lhs, rhs| lhs * rhs)
apply2::<_, _, _, Linear>(self, rhs, |lhs, rhs| lhs * rhs)
}
}
impl<T> core::ops::Sub for &Signal<T>
where
T: core::ops::Sub<T, Output = T> + Copy + LinearInterpolatable + PartialOrd + NumCast,
T: core::ops::Sub<T, Output = T> + Copy + PartialOrd,
Linear: InterpolationMethod<T> + FindIntersectionMethod<T>,
{
type Output = Signal<T>;
/// Subtract the given signal with another
fn sub(self, rhs: Self) -> Self::Output {
use super::InterpolationMethod::Linear;
// This has to be manually implemented and cannot use the apply2 functions.
// This is because if we have two signals that cross each other, then there is
// an intermediate point where the two signals are equal. This point must be
@ -60,12 +64,12 @@ where
}
// the union of the sample points in self and other
let sync_points = self.sync_with_intersection(rhs).unwrap();
let sync_points = self.sync_with_intersection::<Linear>(rhs).unwrap();
sync_points
.into_iter()
.map(|t| {
let lhs = self.interpolate_at(t, Linear).unwrap();
let rhs = rhs.interpolate_at(t, Linear).unwrap();
let lhs = self.interpolate_at::<Linear>(t).unwrap();
let rhs = rhs.interpolate_at::<Linear>(t).unwrap();
(t, lhs - rhs)
})
.collect()
@ -74,25 +78,27 @@ where
impl<T> core::ops::Div for &Signal<T>
where
T: core::ops::Div<T, Output = T> + Copy + LinearInterpolatable,
T: core::ops::Div<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Divide the given signal with another
fn div(self, rhs: Self) -> Self::Output {
apply2(self, rhs, |lhs, rhs| lhs / rhs)
apply2::<_, _, _, Linear>(self, rhs, |lhs, rhs| lhs / rhs)
}
}
impl<T> num_traits::Pow<Self> for &Signal<T>
where
T: num_traits::Pow<T, Output = T> + Copy + LinearInterpolatable,
T: num_traits::Pow<T, Output = T> + Copy,
Linear: InterpolationMethod<T>,
{
type Output = Signal<T>;
/// Returns the values in `self` to the power of the values in `other`
fn pow(self, other: Self) -> Self::Output {
apply2(self, other, |lhs, rhs| lhs.pow(rhs))
apply2::<_, _, _, Linear>(self, other, |lhs, rhs| lhs.pow(rhs))
}
}

View file

@ -3,12 +3,13 @@ use core::time::Duration;
use itertools::Itertools;
use super::traits::LinearInterpolatable;
use super::interpolation::Linear;
use super::{InterpolationMethod, Signal};
impl<T> Signal<T>
where
T: Copy + LinearInterpolatable,
T: Copy,
Linear: InterpolationMethod<T>,
{
/// Shift all samples in the signal by `delta` amount to the left.
///
@ -34,7 +35,7 @@ where
if idx > 0 && first_t != &delta {
// The shifted signal will not start at 0, and we have a previous
// index to interpolate from.
let v = self.interpolate_at(delta, InterpolationMethod::Linear).unwrap();
let v = self.interpolate_at::<Linear>(delta).unwrap();
new_samples.push((Duration::ZERO, v));
}
// Shift the rest of the samples

View file

@ -10,155 +10,23 @@ use super::utils::Neighborhood;
use super::{Sample, Signal};
use crate::ArgusResult;
/// Trait for values that are linear interpolatable
pub trait LinearInterpolatable {
/// Compute the linear interpolation of two samples at `time`
/// Trait implemented by interpolation strategies
pub trait InterpolationMethod<T> {
/// Compute the interpolation of two samples at `time`.
///
/// This should assume that the `time` value is between the sample times of `a` and
/// `b`. This should be enforced as an assertion.
fn interpolate_at(a: &Sample<Self>, b: &Sample<Self>, time: Duration) -> Self
where
Self: Sized;
/// Returns `None` if it isn't possible to interpolate at the given time using the
/// given samples.
fn at(a: &Sample<T>, b: &Sample<T>, time: Duration) -> Option<T>;
}
/// Trait implemented by interpolation strategies that allow finding the intersection of
/// two signal segments defined by start and end samples (see [`Neighborhood`]).
pub trait FindIntersectionMethod<T>: InterpolationMethod<T> {
/// Given two signals with two sample points each, find the intersection of the two
/// lines.
fn find_intersection(a: &Neighborhood<Self>, b: &Neighborhood<Self>) -> Sample<Self>
where
Self: Sized;
fn find_intersection(a: &Neighborhood<T>, b: &Neighborhood<T>) -> Sample<T>;
}
impl LinearInterpolatable for bool {
fn interpolate_at(a: &Sample<Self>, b: &Sample<Self>, time: Duration) -> Self
where
Self: Sized,
{
assert!(a.time < time && time < b.time);
// We can't linear interpolate a boolean, so we return the previous.
a.value
}
fn find_intersection(a: &Neighborhood<Self>, b: &Neighborhood<Self>) -> Sample<Self>
where
Self: Sized,
{
let Sample { time: ta1, value: ya1 } = a.first.unwrap();
let Sample { time: ta2, value: ya2 } = a.second.unwrap();
let Sample { time: tb1, value: yb1 } = b.first.unwrap();
let Sample { time: tb2, value: yb2 } = b.second.unwrap();
let left_cmp = ya1.cmp(&yb1);
let right_cmp = ya2.cmp(&yb2);
if left_cmp.is_eq() {
// They already intersect, so we return the inner time-point
if ta1 < tb1 {
Sample { time: tb1, value: yb1 }
} else {
Sample { time: ta1, value: ya1 }
}
} else if right_cmp.is_eq() {
// They intersect at the end, so we return the outer time-point, as that is
// when they become equal.
if ta2 < tb2 {
Sample { time: tb2, value: yb2 }
} else {
Sample { time: ta2, value: ya2 }
}
} else {
// The switched, so the one that switched earlier will intersect with the
// other.
// So, we find the one that has a lower time point, i.e., the inner one.
if ta2 < tb2 {
Sample { time: ta2, value: ya2 }
} else {
Sample { time: tb2, value: yb2 }
}
}
}
}
macro_rules! interpolate_for_num {
($ty:ty) => {
impl LinearInterpolatable for $ty {
fn interpolate_at(first: &Sample<Self>, second: &Sample<Self>, time: Duration) -> Self
where
Self: Sized,
{
use num_traits::cast;
// We will need to cast the samples to f64 values (along with the time
// window) to be able to interpolate correctly.
// TODO(anand): Verify this works.
let t1 = first.time.as_secs_f64();
let t2 = second.time.as_secs_f64();
let at = time.as_secs_f64();
assert!((t1..=t2).contains(&at));
// We need to do stable linear interpolation
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p0811r3.html
let a: f64 = cast(first.value).unwrap();
let b: f64 = cast(second.value).unwrap();
// Set t to a value in [0, 1]
let t = (at - t1) / (t2 - t1);
assert!((0.0..=1.0).contains(&t));
let val = if (a <= 0.0 && b >= 0.0) || (a >= 0.0 && b <= 0.0) {
t * b + (1.0 - t) * a
} else if t == 1.0 {
b
} else {
a + t * (b - a)
};
cast(val).unwrap()
}
fn find_intersection(a: &Neighborhood<Self>, b: &Neighborhood<Self>) -> Sample<Self>
where
Self: Sized,
{
// https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection#Given_two_points_on_each_line
use num_traits::cast;
let Sample { time: t1, value: y1 } = a.first.unwrap();
let Sample { time: t2, value: y2 } = a.second.unwrap();
let Sample { time: t3, value: y3 } = b.first.unwrap();
let Sample { time: t4, value: y4 } = b.second.unwrap();
let t1 = t1.as_secs_f64();
let t2 = t2.as_secs_f64();
let t3 = t3.as_secs_f64();
let t4 = t4.as_secs_f64();
let y1: f64 = cast(y1).unwrap();
let y2: f64 = cast(y2).unwrap();
let y3: f64 = cast(y3).unwrap();
let y4: f64 = cast(y4).unwrap();
let denom = ((t1 - t2) * (y3 - y4)) - ((y1 - y2) * (t3 - t4));
let t_top = (((t1 * y2) - (y1 * t2)) * (t3 - t4)) - ((t1 - t2) * (t3 * y4 - y3 * t4));
let y_top = (((t1 * y2) - (y1 * t2)) * (y3 - y4)) - ((y1 - y2) * (t3 * y4 - y3 * t4));
let t = Duration::from_secs_f64(t_top / denom);
let y: Self = cast(y_top / denom).unwrap();
Sample { time: t, value: y }
}
}
};
}
interpolate_for_num!(i8);
interpolate_for_num!(i16);
interpolate_for_num!(i32);
interpolate_for_num!(i64);
interpolate_for_num!(u8);
interpolate_for_num!(u16);
interpolate_for_num!(u32);
interpolate_for_num!(u64);
interpolate_for_num!(f32);
interpolate_for_num!(f64);
/// Simple trait to be used as a trait object for [`Signal<T>`] types.
///
/// This is mainly for external libraries to use for trait objects and downcasting to

View file

@ -8,7 +8,6 @@ use core::ops::{Bound, RangeBounds};
use core::time::Duration;
use std::iter::zip;
use super::traits::LinearInterpolatable;
use super::{InterpolationMethod, Sample, Signal};
/// The neighborhood around a signal such that the time `at` is between the `first` and
@ -42,11 +41,12 @@ where
}
#[inline]
pub fn apply2<'a, T, U, F>(lhs: &'a Signal<T>, rhs: &'a Signal<T>, op: F) -> Signal<U>
pub fn apply2<'a, T, U, F, Interp>(lhs: &'a Signal<T>, rhs: &'a Signal<T>, op: F) -> Signal<U>
where
T: Copy + LinearInterpolatable,
T: Copy,
U: Copy,
F: Fn(T, T) -> U,
Interp: InterpolationMethod<T>,
{
use Signal::*;
// If either of the signals are empty, we return an empty signal.
@ -67,8 +67,8 @@ where
time_points
.into_iter()
.map(|t| {
let v1 = lhs.interpolate_at(*t, InterpolationMethod::Linear).unwrap();
let v2 = rhs.interpolate_at(*t, InterpolationMethod::Linear).unwrap();
let v1 = lhs.interpolate_at::<Interp>(*t).unwrap();
let v2 = rhs.interpolate_at::<Interp>(*t).unwrap();
(*t, op(v1, v2))
})
.collect()