diff --git a/transformation/matcher.py b/transformation/matcher.py index 319b05e..3b38932 100644 --- a/transformation/matcher.py +++ b/transformation/matcher.py @@ -12,7 +12,7 @@ import itertools import re import functools -from util.timer import Timer +from util.timer import Timer, counted class _is_edge: def __repr__(self): @@ -323,6 +323,27 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}): g_names, guest = model_to_graph(state, pattern_m, pattern_mm, _filter=is_matchable) + # precompute the candidates for every guest vertex: + guest_to_host_candidate_vtxs = {} + vtxs_of_host_type = {} + + for g_vtx in guest.vtxs: + object_node = g_vtx.node_id + if hasattr(g_vtx, 'typ'): + orig_class_node = ramify.get_original_type(bottom, g_vtx.typ) + orig_class_name = odapi.get_name(orig_class_node) + if orig_class_name in vtxs_of_host_type: + cands = vtxs_of_host_type[orig_class_name] + else: + cands = vtxs_of_host_type[orig_class_name] = len(odapi.get_all_instances(orig_class_name, include_subtypes=True)) + else: + cands = len(host.vtxs) + guest_to_host_candidate_vtxs[g_vtx] = cands + + # print(guest_to_host_candidate_vtxs) + + + # transform 'pivot' into something VF2 understands graph_pivot = { g_names[guest_name] : h_names[host_name] for guest_name, host_name in pivot.items() @@ -376,7 +397,7 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}): compare = RAMCompare(bottom, services_od.OD(host_mm, host_m, state)) - matcher = MatcherVF2(host, guest, compare) + matcher = MatcherVF2(host, guest, compare, guest_to_host_candidate_vtxs) for m in matcher.match(graph_pivot): # Convert mapping name_mapping = {} diff --git a/transformation/vf2.py b/transformation/vf2.py index 987b084..4aec618 100644 --- a/transformation/vf2.py +++ b/transformation/vf2.py @@ -4,7 +4,7 @@ import itertools -from util.timer import Timer +from util.timer import Timer, counted # like finding the 'strongly connected componenets', but edges are navigable in any 
direction def find_connected_components(graph): @@ -81,6 +81,11 @@ class MatcherState: state = MatcherState() state.h_unmatched_vtxs = [vtx for vtx in host.vtxs if vtx not in pivot.values()] state.g_unmatched_vtxs = [vtx for vtx in guest.vtxs if vtx not in pivot.keys()] + # if guest_to_host_candidates != None: + # state.g_unmatched_vtxs.sort( + # # performance thingy: + # # try to match guest vtxs with few candidates first (fail early!): + # key=lambda guest_vtx: guest_to_host_candidates.get(guest_vtx, 0)) state.mapping_vtxs = pivot state.r_mapping_vtxs = { v: k for k,v in state.mapping_vtxs.items() } return state @@ -129,27 +134,41 @@ class MatcherState: # return self.make_hashable().__repr__() return "VTXS: "+self.mapping_vtxs.__repr__()+"\nEDGES: "+self.mapping_edges.__repr__() + class MatcherVF2: # Guest is the pattern - def __init__(self, host, guest, compare_fn): + def __init__(self, host, guest, compare_fn, guest_to_host_candidates=None): self.host = host self.guest = guest self.compare_fn = compare_fn + # map guest vertex to number of candidate vertices in host graph: + if guest_to_host_candidates != None: + self.guest_to_host_candidates = guest_to_host_candidates + else: + # attempt to match every guest vertex with every host vertex (slow!) 
+ self.guest_to_host_candidates = { g_vtx : len(host.vtxs) for g_vtx in guest.vtxs } + # with Timer("find_connected_components - guest"): self.guest_vtx_to_component, self.guest_component_to_vtxs = find_connected_components(guest) - # print("number of guest connected components:", len(self.guest_component_to_vtxs)) + for component in self.guest_component_to_vtxs: + pass + component.sort(key=lambda guest_vtx: self.guest_to_host_candidates[guest_vtx]) # use the instance map: the constructor argument may be None + if len(self.guest_component_to_vtxs) > 1: + print("warning: pattern has multiple components:", len(self.guest_component_to_vtxs)) def match(self, pivot={}): yield from self._match( state=MatcherState.make_initial(self.host, self.guest, pivot), already_visited=set()) - + # @counted def _match(self, state, already_visited, indent=0): # input() + num_matches = 0 + def print_debug(*args): pass # print(" "*indent, *args) # uncomment to see a trace of the matching process @@ -161,7 +180,7 @@ class MatcherVF2: if hashable in already_visited: print_debug(" SKIP - ALREADY VISITED") # print_debug(" ", hashable) - return + return 0 # print_debug(" ", [hash(a) for a in already_visited]) # print_debug(" ADD STATE") # print_debug(" ", hash(hashable)) @@ -173,7 +192,7 @@ class MatcherVF2: print_debug(" ", state.mapping_vtxs) print_debug(" ", state.mapping_edges) yield state - return + return 1 def read_edge(edge, direction): if direction == "outgoing": @@ -184,6 +203,7 @@ class MatcherVF2: raise Exception("wtf!") def attempt_grow(direction, indent): + num_matches = 0 for g_matched_vtx, h_matched_vtx in state.mapping_vtxs.items(): print_debug('attempt_grow', direction) for g_candidate_edge in getattr(g_matched_vtx, direction): @@ -204,46 +224,48 @@ class MatcherVF2: print_debug('grow edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge)) new_state = state.grow_edge(h_candidate_edge, g_candidate_edge) h_candidate_vtx = read_edge(h_candidate_edge, direction) - yield from attempt_match_vtxs( + num_matches 
+= yield from attempt_match_vtxs( new_state, g_candidate_vtx, h_candidate_vtx, indent+1) print_debug('backtrack edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge)) + return num_matches def attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent): print_debug('attempt_match_vtxs') if g_candidate_vtx in state.mapping_vtxs: if state.mapping_vtxs[g_candidate_vtx] != h_candidate_vtx: print_debug(" nope, guest already mapped (mismatch)") - return # guest vtx is already mapped but doesn't match host vtx + return 0 # guest vtx is already mapped but doesn't match host vtx if h_candidate_vtx in state.r_mapping_vtxs: if state.r_mapping_vtxs[h_candidate_vtx] != g_candidate_vtx: print_debug(" nope, host already mapped (mismatch)") - return # host vtx is already mapped but doesn't match guest vtx + return 0 # host vtx is already mapped but doesn't match guest vtx g_outdegree = len(g_candidate_vtx.outgoing) h_outdegree = len(h_candidate_vtx.outgoing) if g_outdegree > h_outdegree: print_debug(" nope, outdegree") - return + return 0 g_indegree = len(g_candidate_vtx.incoming) h_indegree = len(h_candidate_vtx.incoming) if g_indegree > h_indegree: print_debug(" nope, indegree") - return + return 0 if not self.compare_fn(g_candidate_vtx, h_candidate_vtx): print_debug(" nope, bad compare") - return + return 0 new_state = state.grow_vtx( h_candidate_vtx, g_candidate_vtx) print_debug('grow vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx)) - yield from self._match(new_state, already_visited, indent+1) + num_matches = yield from self._match(new_state, already_visited, indent+1) print_debug('backtrack vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx)) + return num_matches print_debug('preferred...') - yield from attempt_grow('outgoing', indent+1) - yield from attempt_grow('incoming', indent+1) + num_matches += yield from attempt_grow('outgoing', indent+1) + num_matches 
+= yield from attempt_grow('incoming', indent+1) print_debug('least preferred...') if state.boundary != None: @@ -257,15 +279,25 @@ class MatcherVF2: for g_candidate_vtxs in guest_components_to_try: for g_candidate_vtx in g_candidate_vtxs: + g_vtx_matches = 0 + g_vtx_max = self.guest_to_host_candidates[g_candidate_vtx] + # print(' guest vtx has', g_vtx_max, ' host candidates') if g_candidate_vtx in state.mapping_vtxs: print_debug("skip (already matched)", g_candidate_vtx) continue for h_candidate_vtx in state.h_unmatched_vtxs: - yield from attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent+1) + N = yield from attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent+1) + g_vtx_matches += N > 0 + num_matches += N + if g_vtx_matches == g_vtx_max: + # print("EARLY STOP") + break # found all matches if indent == 0: print_debug('visited', len(already_visited), 'states total') + return num_matches + # demo time... if __name__ == "__main__": host = Graph()