muMLE/transformation/schedule/schedule_lib/data_node.py

101 lines
3.4 KiB
Python

from abc import abstractmethod
from typing import Any, Generator, List, override
from jinja2 import Template
from .data import Data
from .funcs import generate_dot_edge
from .node import Node
class DataNodeState:
    """State object associated with data nodes.

    It currently defines no attributes of its own. Its constructor only
    forwards to the next class in the MRO, so it can be combined with other
    classes through cooperative multiple inheritance.
    """

    def __init__(self) -> None:
        # Explicit two-argument form; equivalent to the zero-argument super().
        super(DataNodeState, self).__init__()
class DataNode(Node):
    """Schedule node that passes data between named input and output gates.

    Each output gate owns a ``Data`` buffer stored in ``data_out``. Each
    input gate in ``data_in`` refers to an upstream node's output ``Data``
    object, or is ``None`` while unconnected. Nodes subscribed to an output
    gate through ``eventsub`` have ``input_event`` called on them whenever
    ``store_data`` writes to that gate.

    Subclasses choose their gate names by overriding
    ``get_data_input_gates`` / ``get_data_output_gates``.
    """

    def __init__(self) -> None:
        super().__init__()
        # output gate -> [(subscriber node, subscriber's input gate), ...]
        self.eventsub: dict[str, list[tuple[DataNode, str]]] = {
            gate: [] for gate in self.get_data_output_gates()
        }
        # output gate -> Data buffer owned by this node (parent = self)
        self.data_out: dict[str, Data] = {
            name: Data(self) for name in self.get_data_output_gates()
        }
        # input gate -> upstream node's Data buffer; None until connected
        # (see connect_data)
        self.data_in: dict[str, Data | None] = {
            name: None for name in self.get_data_input_gates()
        }

    @staticmethod
    def get_data_input_gates() -> List[str]:
        """Return the names of the data input gates (default: ``["in"]``)."""
        return ["in"]

    @staticmethod
    def get_data_output_gates() -> List[str]:
        """Return the names of the data output gates (default: ``["out"]``)."""
        return ["out"]

    @override
    def generate_stack_frame(self, exec_id: int) -> None:
        """Set up the stack frame for ``exec_id`` on this node and its outputs.

        Runs the base-class implementation first, then calls
        ``generate_stack_frame(exec_id)`` on every output ``Data`` buffer.
        """
        super().generate_stack_frame(exec_id)
        for out_data in self.data_out.values():
            out_data.generate_stack_frame(exec_id)

    @override
    def delete_stack_frame(self, exec_id: int) -> None:
        """Tear down the stack frame for ``exec_id`` on this node and its outputs.

        Runs the base-class implementation first, then calls
        ``delete_stack_frame(exec_id)`` on every output ``Data`` buffer.
        """
        super().delete_stack_frame(exec_id)
        for out_data in self.data_out.values():
            out_data.delete_stack_frame(exec_id)

    def connect_data(
        self, data_node: "DataNode", from_gate: str, to_gate: str, eventsub=True
    ) -> None:
        """Connect output ``from_gate`` of this node to input ``to_gate`` of ``data_node``.

        The input gate is bound to this node's output ``Data`` object itself,
        so it is shared by reference and not copied.

        Args:
            data_node: Downstream node that receives the data.
            from_gate: One of this node's output gate names.
            to_gate: One of ``data_node``'s input gate names.
            eventsub: When true, subscribe ``data_node`` so that
                ``store_data`` on ``from_gate`` calls
                ``data_node.input_event(to_gate, exec_id)``.

        Raises:
            Exception: If ``from_gate`` or ``to_gate`` is not a valid gate
                name for its node.
        """
        if from_gate not in self.get_data_output_gates():
            raise Exception(f"from_gate {from_gate} is not a valid port")
        if to_gate not in data_node.get_data_input_gates():
            raise Exception(f"to_gate {to_gate} is not a valid port")
        data_node.data_in[to_gate] = self.data_out[from_gate]
        if eventsub:
            # NOTE(review): calling this twice for the same pair adds a
            # duplicate subscription. Rebinding ``to_gate`` to another source
            # does not remove the old source's subscription. Confirm both
            # are intended.
            self.eventsub[from_gate].append((data_node, to_gate))

    def store_data(self, exec_id, data_gen: Generator, port: str, n: int) -> None:
        """Store data on output ``port`` for ``exec_id``, then notify subscribers.

        ``data_gen`` and ``n`` are forwarded unchanged to the output buffer's
        ``store_data``; their exact meaning is defined by ``Data``. After
        storing, every subscriber of ``port`` receives
        ``input_event(gate, exec_id)``, in subscription order.
        """
        self.data_out[port].store_data(exec_id, data_gen, n)
        for sub, gate in self.eventsub[port]:
            sub.input_event(gate, exec_id)

    def get_input_data(self, gate: str, exec_id: int) -> list[dict[Any, Any]]:
        """Return the data visible on input ``gate`` for ``exec_id``.

        An unconnected gate yields ``[{}]``, a list holding one empty dict,
        rather than an empty list.
        """
        source_data = self.data_in[gate]
        if source_data is None:
            return [{}]
        return source_data.get_data(exec_id)

    @abstractmethod
    def input_event(self, gate: str, exec_id: int) -> None:
        """Handle new data arriving on ``gate`` for ``exec_id``.

        Declared abstract for subclasses to implement. The default body
        forwards the event to every subscriber registered in ``eventsub``
        under ``gate``.
        """
        # NOTE(review): ``eventsub`` is keyed by *output* gate names. This
        # default therefore only works when ``gate`` is also an output gate
        # name; with the default "in"/"out" gates, ``eventsub["in"]`` raises
        # KeyError. Presumably subclasses call it with an output gate; confirm.
        for sub, sub_gate in self.eventsub[gate]:
            sub.input_event(sub_gate, exec_id)

    def generate_dot(
        self, nodes: List[str], edges: List[str], visited: set[int], template: Template
    ) -> None:
        """Emit DOT edges for incoming data connections and recurse through the graph.

        For every connected input gate:
        - Render one green edge (prefix ``"d"``) via ``generate_dot_edge``,
          from the upstream node's output gate to this node's input gate.
        - Then recurse into the upstream node.

        Finally, recurse into every event subscriber of this node's outputs.

        NOTE(review): this method neither reads nor updates ``visited``.
        Termination on cyclic graphs presumably relies on overriding
        subclasses performing that check; confirm.
        """
        for to_gate, data in self.data_in.items():
            if data is None:
                continue
            source = data.get_parent()
            # The input holds the upstream node's own output Data object, so
            # find which of the source's output gates it belongs to.
            from_gate = [
                out_gate
                for out_gate, out_data in source.data_out.items()
                if out_data == data
            ][0]
            generate_dot_edge(
                source,
                self,
                edges,
                template,
                kwargs={
                    "prefix": "d",
                    "from_gate": from_gate,
                    "to_gate": to_gate,
                    "color": "green",
                },
            )
            source.generate_dot(nodes, edges, visited, template)
        for subs in self.eventsub.values():
            for sub, _sub_gate in subs:
                sub.generate_dot(nodes, edges, visited, template)