test(pyargus): add tests comparing semantics against metric-temporal-logic

This commit is contained in:
Anand Balakrishnan 2023-10-15 14:39:09 -07:00
parent 192bb20380
commit 77a9106e8b
No known key found for this signature in database
6 changed files with 135 additions and 18 deletions

View file

@ -136,7 +136,7 @@ def mypy(session: nox.Session):
def tests(session: nox.Session):
session.conda_install("pytest", "hypothesis", "lark")
session.env.update(ENV)
session.install("./pyargus")
session.install("./pyargus[test]")
try:
session.run(
"cargo", "test", "--workspace", "--exclude", "pyargus", external=True
@ -181,6 +181,8 @@ def coverage(session: nox.Session):
"develop",
"-m",
"./pyargus/Cargo.toml",
"-E",
"test",
silent=True,
)
try:

View file

@ -35,24 +35,34 @@ def gen_samples(
min_size: int,
max_size: int,
dtype_: Union[Type[AllowedDtype], dtype],
) -> List[Tuple[float, AllowedDtype]]:
n_lists: int = 1,
) -> Union[List[Tuple[float, AllowedDtype]], List[List[Tuple[float, AllowedDtype]]]]:
"""
Generate arbitrary samples for a signal where the time stamps are strictly
monotonically increasing
:param n_lists: used to generate multiple sample lists with the same time domain. This is used for testing against
`metric-temporal-logic` as it doesn't check for non-overlapping domains.
"""
elements = gen_element_fn(dtype_)
values = draw(st.lists(elements, min_size=min_size, max_size=max_size))
xs = draw(
n_lists = max(1, n_lists)
timestamps = draw(
st.lists(
st.integers(min_value=0, max_value=2**32 - 1),
unique=True,
min_size=len(values),
max_size=len(values),
)
.map(lambda t: map(float, sorted(set(t))))
.map(lambda t: list(zip(t, values)))
min_size=min_size,
max_size=max_size,
).map(lambda t: list(map(float, sorted(set(t)))))
)
return xs
elements = gen_element_fn(dtype_)
sample_lists = [
draw(st.lists(elements, min_size=len(timestamps), max_size=len(timestamps)).map(lambda xs: list(zip(timestamps, xs))))
for _ in range(n_lists)
]
if n_lists == 1:
return sample_lists[0]
else:
return sample_lists
def empty_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:

View file

@ -35,7 +35,13 @@ dev = [
"black",
]
test = ["pytest", "coverage", "hypothesis[lark]"]
test = [
"pytest",
"coverage",
"hypothesis[lark]",
"metric-temporal-logic",
"rtamt",
]
[build-system]
requires = ["maturin>=1.0,<2.0"]
@ -54,6 +60,11 @@ packages = ["argus"]
# ignore_missing_imports = true
show_error_codes = true
[[tool.mypy.overrides]]
module = "mtl"
ignore_missing_imports = true
[tool.ruff]
line-length = 127
select = ["E", "F", "W", "N", "B", "ANN", "PYI"]

View file

@ -4,11 +4,12 @@ mod signals;
use argus::Error as ArgusError;
use ariadne::Source;
use expr::PyExpr;
use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
use crate::expr::{PyBoolExpr, PyNumExpr};
#[derive(derive_more::From)]
struct PyArgusError(ArgusError);
@ -73,10 +74,14 @@ impl DType {
/// Parse a string expression into a concrete Argus expression.
#[pyfunction]
fn parse_expr(expr_str: &str) -> PyResult<PyObject> {
use argus::expr::Expr;
use ariadne::{Color, Label, Report, ReportKind};
match argus::parse_str(expr_str) {
Ok(expr) => Python::with_gil(|py| PyExpr::from_expr(py, expr)),
Ok(expr) => Python::with_gil(|py| match expr {
Expr::Bool(e) => PyBoolExpr::from_expr(py, e),
Expr::Num(e) => PyNumExpr::from_expr(py, e),
}),
Err(errs) => {
let mut buf = Vec::new();
{

View file

@ -0,0 +1,89 @@
from typing import List, Tuple
import hypothesis.strategies as st
import mtl
from hypothesis import given
import argus
from argus.test_utils.signals_gen import gen_samples
@given(
sample_lists=gen_samples(min_size=3, max_size=50, dtype_=bool, n_lists=2),
spec=st.one_of(
[
st.just(spec)
for spec in [
"a",
"~a",
"(a & b)",
"(a | b)",
"(a -> b)",
"(a <-> b)",
"(a ^ b)",
]
]
),
)
def test_boolean_propositional_expr(
sample_lists: List[List[Tuple[float, bool]]],
spec: str,
) -> None:
mtl_spec = mtl.parse(spec)
argus_spec = argus.parse_expr(spec)
assert isinstance(argus_spec, argus.BoolExpr)
a, b = sample_lists
mtl_data = dict(a=a, b=b)
argus_data = argus.Trace(
dict(
a=argus.BoolSignal.from_samples(a, interpolation_method="constant"),
b=argus.BoolSignal.from_samples(b, interpolation_method="constant"),
)
)
mtl_rob = mtl_spec(mtl_data, quantitative=False)
argus_rob = argus.eval_bool_semantics(argus_spec, argus_data)
assert mtl_rob == argus_rob.at(0), f"{argus_rob=}"
@given(
sample_lists=gen_samples(min_size=3, max_size=50, dtype_=bool, n_lists=2),
spec=st.one_of(
[
st.just(spec)
for spec in [
"F a",
"G b",
"(G(a & b))",
"(F(a | b))",
"G(a -> F[0,2] b)",
"(a U b)",
"(a U[0,2] b)",
]
]
),
)
def test_boolean_temporal_expr(
sample_lists: List[List[Tuple[float, bool]]],
spec: str,
) -> None:
mtl_spec = mtl.parse(spec)
argus_spec = argus.parse_expr(spec)
assert isinstance(argus_spec, argus.BoolExpr)
a = sample_lists[0]
b = sample_lists[1]
mtl_data = dict(a=a, b=b)
argus_data = argus.Trace(
dict(
a=argus.BoolSignal.from_samples(a, interpolation_method="constant"),
b=argus.BoolSignal.from_samples(b, interpolation_method="constant"),
)
)
mtl_rob = mtl_spec(mtl_data, quantitative=False)
argus_rob = argus.eval_bool_semantics(argus_spec, argus_data)
assert mtl_rob == argus_rob.at(0), f"{argus_rob=}"

View file

@ -29,7 +29,7 @@ def test_correct_constant_signals(data: st.DataObject) -> None:
def test_correctly_create_signals(data: st.DataObject) -> None:
dtype_ = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype_=dtype_))
signal = sampled_signal(xs, dtype_)
signal = sampled_signal(xs, dtype_) # type: ignore
assert isinstance(signal, argus.Signal)
if len(xs) > 0:
expected_start_time = xs[0][0]
@ -46,7 +46,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
a = data.draw(draw_index(xs))
assert a < len(xs)
at, expected_val = xs[a]
actual_val = signal.at(at)
actual_val = signal.at(at) # type: ignore
assert actual_val is not None
assert actual_val == expected_val
@ -79,10 +79,10 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
assert a < len(xs)
assert b < len(xs)
# Swap two indices in the samples
xs[b], xs[a] = xs[a], xs[b]
xs[b], xs[a] = xs[a], xs[b] # type: ignore
with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
_ = sampled_signal(xs, dtype_)
_ = sampled_signal(xs, dtype_) # type: ignore
@given(st.data())