Added an eval_context_decorator to allow user defined functions in rules

This commit is contained in:
robbe 2025-06-27 12:20:00 +02:00
parent e4ea9a0410
commit ec42f74960

View file

@ -15,6 +15,30 @@ from api.od import ODAPI, bind_api_readonly
import functools
def eval_context_decorator(func):
"""
Used to mark functions that can be called inside the evaluation context.
Base functions are mapped into the function, as well as the evaluation context.
This happens at runtime so typechecking will not be happy.
Important: Using the same name in the evaluation context as the function name
will lead to naming conflicts with the function as priority, resulting in missing argument errors.
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from api.od_stub import *
...
Use this to make partially fix the typechecking.
Optionally, define a stub for your own evaluation context and include it.
"""
def wrapper(*args, api_context, eval_context, **kwargs):
for key, value in api_context.items():
func.__globals__[key] = value
for key, value in eval_context.items():
func.__globals__[key] = value
return func(*args, **kwargs)
return wrapper
def render_conformance_check_result(error_list):
if len(error_list) == 0:
return "CONFORM"
@ -25,7 +49,7 @@ def render_conformance_check_result(error_list):
class Conformance:
# Parameter 'constraint_check_subtypes': whether to check local type-level constraints also on subtypes.
def __init__(self, state: State, model: UUID, type_model: UUID, constraint_check_subtypes=True):
def __init__(self, state: State, model: UUID, type_model: UUID, constraint_check_subtypes=True, *, eval_context = None):
self.state = state
self.bottom = Bottom(state)
self.model = model
@ -51,6 +75,9 @@ class Conformance:
self.structures = {}
self.candidates = {}
# add user defined functions to constraints
self.eval_context = eval_context if eval_context else {}
def check_nominal(self, *, log=False):
"""
@ -248,6 +275,13 @@ class Conformance:
raise Exception(f"{description} evaluation result should be boolean or string! Instead got {result}")
# local constraints
_api_context = bind_api_readonly(self.odapi)
_global_binds = {**_api_context}
_eval_context = {**self.eval_context}
for key, code in _eval_context.items():
_f = functools.partial(code, **{"api_context" :_api_context, "eval_context":_eval_context})
_global_binds[key] = _f
_eval_context[key] = _f
for type_name in self.bottom.read_keys(self.type_model):
code = get_code(type_name)
if code != None:
@ -256,7 +290,7 @@ class Conformance:
description = f"Local constraint of \"{type_name}\" in \"{obj_name}\""
# print(description)
try:
result = exec_then_eval(code, _globals=bind_api_readonly(self.odapi), _locals={'this': obj_id}) # may raise
result = exec_then_eval(code, _globals=_global_binds, _locals={'this': obj_id}) # may raise
check_result(result, description)
except:
errors.append(f"Runtime error during evaluation of {description}:\n{indent(traceback.format_exc(), 6)}")
@ -278,7 +312,7 @@ class Conformance:
if code != None:
description = f"Global constraint \"{tm_name}\""
try:
result = exec_then_eval(code, _globals=bind_api_readonly(self.odapi)) # may raise
result = exec_then_eval(code, _globals=_global_binds) # may raise
check_result(result, description)
except:
errors.append(f"Runtime error during evaluation of {description}:\n{indent(traceback.format_exc(), 6)}")