add profiler to give performance insights, seems like matching is usually the bottleneck

This commit is contained in:
Joeri Exelmans 2024-11-22 16:25:08 +01:00
parent 5962a476c0
commit 4fe7e19714
3 changed files with 130 additions and 54 deletions

View file

@ -206,6 +206,11 @@ def model_to_graph(state: State, model: UUID, metamodel: UUID,
return names, graph
class _No_Matched(Exception):
pass
def _cannot_call_matched(_):
raise _No_Matched()
# This function returns a Generator of matches.
# The idea is that the user can iterate over the match set, lazily generating it: if only interested in the first match, the entire match set doesn't have to be generated.
def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
@ -265,10 +270,21 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
python_code = services_od.read_primitive_value(self.bottom, g_vtx.node_id, pattern_mm)[0]
try:
# Try to execute code, but if the `matched` API-function is called, we fail.
with Timer(f'EVAL condition {g_vtx.name}'):
ok = exec_then_eval(python_code,
_globals={
**bind_api_readonly(odapi),
'matched': _cannot_call_matched,
},
_locals={'this': h_vtx.node_id})
self.conditions_to_check.pop(g_vtx.name, None)
return ok
except _No_Matched:
# The code made a call to the `matched`-function.
self.conditions_to_check[g_vtx.name] = python_code
# self.conditions_to_check.append((python_code, h_vtx.name, g_vtx.name))
return True # do be determined later, if it's actually a match
return True # to be determined later, if it's actually a match
if g_vtx.value == None:
return h_vtx.value == None
@ -339,18 +355,21 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
except KeyError:
continue
host_node = odapi.get(host_name)
with Timer(f'EVAL condition {pattern_name}'):
if not check(python_code, {'this': host_node}):
return False
for python_code, pattern_el_name in obj_conditions:
if pattern_el_name == None:
# GlobalCondition
with Timer(f'EVAL all global conditions'):
if not check(python_code, {}):
return False
else:
# object-lvl condition
host_el_name = name_mapping[pattern_el_name]
host_node = odapi.get(host_el_name)
with Timer(f'EVAL local condition {pattern_el_name}'):
if not check(python_code, {'this': host_node}):
return False
return True

View file

@ -8,6 +8,7 @@ from concrete_syntax.common import indent
from transformation.matcher import match_od
from transformation.rewriter import rewrite
from transformation.cloner import clone_od
from util.timer import Timer
class Rule:
def __init__(self, nacs: list[UUID], lhs: UUID, rhs: UUID):
@ -18,6 +19,9 @@ class Rule:
PP = pprint.PrettyPrinter(depth=4)
class _NAC_MATCHED(Exception):
pass
# Helper for executing NAC/LHS/RHS-type rules
class RuleMatcherRewriter:
def __init__(self, state, mm: UUID, mm_ramified: UUID):
@ -36,11 +40,18 @@ class RuleMatcherRewriter:
try:
# First we iterate over LHS-matches:
for i, lhs_match in enumerate(lhs_matcher):
# for i, lhs_match in enumerate(lhs_matcher):
x=0
while True:
try:
with Timer(f"MATCH LHS {rule_name}"):
lhs_match = lhs_matcher.__next__()
x += 1
nac_matched = False
for nac in nacs:
try:
for i_nac, nac in enumerate(nacs):
# For every LHS-match, we see if there is a NAC-match:
nac_matcher = match_od(self.state,
host_m=m,
@ -50,7 +61,16 @@ class RuleMatcherRewriter:
pivot=lhs_match) # try to "grow" LHS-match with NAC-match
try:
for j, nac_match in enumerate(nac_matcher):
# for nac_match in nac_matcher:
while True:
try:
with Timer(f"MATCH NAC{i_nac} {rule_name}"):
nac_match = nac_matcher.__next__()
raise _NAC_MATCHED()
except StopIteration:
break # no more nac-matches
# The NAC has at least one match
# (there could be more, but we know enough, so let's not waste CPU/MEM resources and proceed to next LHS match)
nac_matched = True
@ -60,13 +80,13 @@ class RuleMatcherRewriter:
# Decorate exception with some context, to help with debugging
e.add_note(f"while matching NAC of '{rule_name}'")
raise
except _NAC_MATCHED:
continue # continue with next LHS-match
if nac_matched:
break
if not nac_matched:
# There were no NAC matches -> yield LHS-match!
yield lhs_match
except StopIteration:
break # no more lhs-matches
except Exception as e:
@ -102,9 +122,17 @@ class ActionGenerator:
def __call__(self, od: ODAPI):
at_least_one_match = False
for rule_name, rule in self.rule_dict.items():
for lhs_match in self.matcher_rewriter.match_rule(od.m, rule.lhs, rule.nacs, rule_name):
match_iterator = self.matcher_rewriter.match_rule(od.m, rule.lhs, rule.nacs, rule_name)
x = 0
while True:
try:
# if True:
with Timer(f"MATCH RULE {rule_name}"):
lhs_match = match_iterator.__next__()
x += 1
# We got a match!
def do_action(od, rule, lhs_match, rule_name):
with Timer(f"EXEC RHS {rule_name}"):
new_m, rhs_match = self.matcher_rewriter.exec_rule(od.m, rule.lhs, rule.rhs, lhs_match, rule_name)
msgs = [f"executed rule '{rule_name}'\n" + indent(PP.pformat(rhs_match), 6)]
return (ODAPI(od.state, new_m, od.mm), msgs)
@ -113,6 +141,8 @@ class ActionGenerator:
functools.partial(do_action, od, rule, lhs_match, rule_name) # the action itself (as a callback)
)
at_least_one_match = True
except StopIteration:
break
return at_least_one_match
# Given a list of actions (in high -> low priority), will always yield the highest priority enabled actions.

View file

@ -1,10 +1,37 @@
import time
import os
class Timer:
if "MUMLE_PROFILER" in os.environ:
import time
import atexit
timings = {}
class Timer:
def __init__(self, text):
self.text = text
def __enter__(self):
self.start_time = time.perf_counter_ns()
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.perf_counter_ns()
print(self.text, (self.end_time - self.start_time)/1000000, "ms")
duration = self.end_time - self.start_time
existing_timing = timings.get(self.text, 0)
timings[self.text] = existing_timing + duration
def __print_timings():
if len(timings)>0:
print(f'Timings:')
tuples = [(text,duration) for text, duration in timings.items()]
tuples.sort(key=lambda tup: -tup[1])
for text, duration in tuples:
print(f' {text} {round(duration/1000000)} ms')
atexit.register(__print_timings)
else:
class Timer:
def __init__(self, text):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass