diff --git a/doc/create_model.rst b/doc/create_model.rst index bec05b0..51a9814 100644 --- a/doc/create_model.rst +++ b/doc/create_model.rst @@ -241,6 +241,16 @@ In terms of computation and inputs, ``advect_model`` is equivalent to the ``advect_model_raw`` instance created above ; it is just organized differently. +Avoiding cycles in the model +---------------------------- +Often, a process involves updating a variable, that is used by other processes in the model. +This may result in a cycle being detected, and the model not able to run. The +process order is created based on variables with ``intent='out'``. Therefore, any variable +that is created with ``intent='inout'`` will be set last in the calculation order. +Any process that uses this variable as an input, will use the veriable from the previous timestep. +For example, the ``u`` variable in ``ProfileU`` has ``intent='inout'``, and is therefore last in the calculation order. + + Update existing models ---------------------- diff --git a/doc/framework.rst b/doc/framework.rst index 2c1a880..9508059 100644 --- a/doc/framework.rst +++ b/doc/framework.rst @@ -273,7 +273,9 @@ in their computation. In a model, the processes and their dependencies together form the nodes and the edges of a Directed Acyclic Graph (DAG). The graph topology is fully determined by the ``intent`` set for each variable -or foreign variable declared in each process. An ordering that is +or foreign variable declared in each process. That is, a process depends on +another process if and only if the 'parent' process has an ``intent='out'`` +variable, that is used by the 'child' process. An ordering that is computationally consistent can then be obtained using topological sorting. This is done at Model object creation. The same ordering is used at every stage of a model run. diff --git a/xsimlab/dot.py b/xsimlab/dot.py index 409494b..55a5930 100644 --- a/xsimlab/dot.py +++ b/xsimlab/dot.py @@ -43,6 +43,7 @@ INPUT_EDGE_ATTRS = {"arrowhead": "none", "color": "#b49434"} VAR_NODE_ATTRS = {"shape": "box", "color": "#555555", "fontcolor": "#555555"} VAR_EDGE_ATTRS = {"arrowhead": "none", "color": "#555555"} +INOUT_EDGE_ATTRS = {"color": "#000000", "style": "dashed"} def _hash_variable(var): @@ -98,12 +99,46 @@ def _add_var(self, var, p_name): if var_intent == VarIntent.OUT: edge_attrs.update({"arrowhead": "empty"}) edge_ends = p_name, var_key + elif var_intent == VarIntent.INOUT: + edge_attrs.update({"arrowhead": "empty"}) + edge_ends = p_name, var_key else: edge_ends = var_key, p_name self.g.node(var_key, label=var.name, **node_attrs) self.g.edge(*edge_ends, weight="200", **edge_attrs) + def add_inout_arrows(self): + for p_name, p_obj in self.model._processes.items(): + p_cls = type(p_obj) + for var_name, var in variables_dict(p_cls).items(): + # test if the variable is inout + if ( + var.metadata["intent"] == VarIntent.INOUT + and var.metadata["var_type"] != VarType.FOREIGN + ): + target_keys = _get_target_keys(p_obj, var_name) + + # now again cycle through all processes to see if there is a variable with the same reference + for p2_name, p2_obj in self.model._processes.items(): + p2_cls = type(p2_obj) + + # skip this if it is a dependent process or the process itself + if ( + p_name in self.model.dependent_processes[p2_name] + or p_name == p2_name + ): + continue + + for var2_name, var2 in variables_dict(p2_cls).items(): + # if the variable is + target2_keys = _get_target_keys(p2_obj, var2_name) + if len(set(target_keys) & set(target2_keys)): + edge_ends = p_name, p2_name + self.g.edge( + *edge_ends, weight="200", **INOUT_EDGE_ATTRS + ) + def add_inputs(self): for p_name, var_name in self.model._input_vars: p_cls = type(self.model[p_name]) @@ -146,6 +181,7 @@ def to_graphviz( show_only_variable=None, show_inputs=False, show_variables=False, + show_inout_arrows=True, graph_attr={}, **kwargs, ): @@ -167,6 +203,9 @@ def to_graphviz( elif show_inputs: builder.add_inputs() + elif show_inout_arrows: + builder.add_inout_arrows() + return builder.get_graph() @@ -211,6 +250,7 @@ def dot_graph( show_only_variable=None, show_inputs=False, show_variables=False, + show_inout_arrows=True, **kwargs, ): """ @@ -236,6 +276,8 @@ def dot_graph( show_variables : bool, optional If True, show also the other variables (default: False). Ignored if `show_only_variable` is not None. + show_inout_arrows : bool, optional + if True, show references to inout variables as dotted lines. (default: True) **kwargs Additional keyword arguments to forward to `to_graphviz`. @@ -262,6 +304,7 @@ def dot_graph( show_only_variable=show_only_variable, show_inputs=show_inputs, show_variables=show_variables, + show_inout_arrows=show_inout_arrows, **kwargs, ) diff --git a/xsimlab/model.py b/xsimlab/model.py index 135aaf6..f88e4e4 100644 --- a/xsimlab/model.py +++ b/xsimlab/model.py @@ -14,7 +14,7 @@ RuntimeSignal, SimulationStage, ) -from .utils import AttrMapping, Frozen, variables_dict +from .utils import AttrMapping, Frozen, variables_dict, as_variable_key from .formatting import repr_model @@ -120,6 +120,7 @@ def __init__(self, processes_cls): self._dep_processes = None self._sorted_processes = None + self._deps_dict = None # a cache for group keys self._group_keys = {} @@ -401,11 +402,13 @@ def get_processes_to_validate(self): return {k: list(v) for k, v in processes_to_validate.items()} - def get_process_dependencies(self): + def get_process_dependencies(self, custom_dependencies={}): """Return a dictionary where keys are each process of the model and values are lists of the names of dependent processes (or empty lists for processes that have no dependencies). + inputs: dependencies: a {'p_name':['dep_p_name','dep2_p_name']} dictionary + Process 1 depends on process 2 if the later declares a variable (resp. a foreign variable) with intent='out' that itself (resp. its target variable) is needed in process 1. @@ -423,6 +426,10 @@ def get_process_dependencies(self): ] ) + # actually add custom dependencies + for p_name, deps in custom_dependencies.items(): + self._dep_processes[p_name].update(deps) + for p_name, p_obj in self._processes_obj.items(): for var in filter_variables(p_obj, intent=VarIntent.OUT).values(): if var.metadata["var_type"] == VarType.ON_DEMAND: @@ -435,13 +442,153 @@ def get_process_dependencies(self): self._dep_processes[pn].add(p_name) self._dep_processes = {k: list(v) for k, v in self._dep_processes.items()} - return self._dep_processes + def _check_inout_vars(self): + """ + checks if all inout variables and corresponding in variables are explicitly set in the dependencies + Out variables always come first, since the get_process_dependencies checks for that. + A well-behaved graph looks like: + ``` + inout1->inout2 + ^ \ ^ \ + / \ / \ + in in in + ``` + needs to be run after _sort_processes + """ + # create dictionaries with all inout variables and input variables + inout_dict = {} # dict of {key:{p1_name,p2_name}} for inout variables + in_dict = {} + + # TODO: improve this: the aim is to create a {key:{p1,p2,p3}} dict, + # where p1,p2,p3 are process names that have the key var as inout, resp. in vars + # some problems are that we can have on_demand and state varibles, + # that key can return a tuple or list, + for p_name, p_obj in self._processes_obj.items(): + # create {key:{p1_name,p2_name}} dicts for in and inout vars. + for var in filter_variables(p_obj, intent=VarIntent.INOUT): + if var in p_obj.__xsimlab_state_keys__: + keys = p_obj.__xsimlab_state_keys__[var] + else: + keys = p_obj.__xsimlab_od_keys__[var] + + if type(keys) == tuple: + keys = [keys] + + for key in keys: + if not key in inout_dict: + inout_dict[key] = {p_name} + else: + inout_dict[key].add(p_name) + + for var in filter_variables(p_obj, intent=VarIntent.IN): + if var in p_obj.__xsimlab_state_keys__: + keys = p_obj.__xsimlab_state_keys__[var] + else: + keys = p_obj.__xsimlab_od_keys__[var] + + if type(keys) == tuple: + keys = [keys] + + for key in keys: + if not key in in_dict: + in_dict[key] = {p_name} + else: + in_dict[key].add(p_name) + + # filter out variables that do not need to be checked (without inputs): + inout_dict = {k: v for k, v in inout_dict.items() if k in in_dict} + + for key, inout_ps in inout_dict.items(): + in_ps = in_dict[key] + + verified_ios = [] + + # now we only have to search and verify all inout variables + # print("checking ", key, " with io processes ", inout_ps) + for io_p in inout_ps: + io_stack = [io_p] + while io_stack: + cur = io_stack[-1] + if cur in verified_ios: + io_stack.pop() + continue + + child_ios = self._deps_dict[io_p].intersection(inout_ps - {cur}) + if child_ios: + # TODO: fix this with intersections + # lost_children = child_ios.symetric_difference(set(verified_ios)) + if child_ios == set(verified_ios): + child_ins = in_ps.intersection(self._deps_dict[cur]) + # verify that all children have the previous io as dependency + for child_in in child_ins: + if not verified_ios[-1] in self._deps_dict[child_in]: + raise RuntimeError( + f"inout process {verified_ios[-1]} not in {child_in}'s " + + "dependencies, could not establish strict dependency order" + ) + # we can now safely remove these in nodes + in_ps -= child_ins + verified_ios.append(cur) + io_stack.pop() + elif child_ios - set(verified_ios): + # we need to search deeper: add to the stack. + io_stack.extend( + [io for io in child_ios if io not in verified_ios] + ) + else: + raise RuntimeError( + f"inout process {cur} depends on {child_ios}, but should\ + depend on all of {verified_ios}, especially {verified_ios[-1]}" + ) + else: + # we are at the bottom inout process: remove in variables from the set + # this can only happen if we are the first process at the bottom + if verified_ios: + raise RuntimeError( + f"inout process {cur} has no dependencies with variable {key}, but {verified_ios} should be one", + ) + in_ps -= self._deps_dict[cur] + verified_ios.append(cur) + io_stack.pop() + + # we finished all inout, and inputs that are descendants of inout + # vars, so all remaining input vars shoudl depend on the last inout var + for p in in_ps: + if not verified_ios[-1] in self._deps_dict[p]: + raise RuntimeError( + f"process {verified_ios[-1]} not in depdendencies of {p} while {key} requires so" + ) + + def transitive_reduction(self): + """Returns transitive reduction of a directed graph + + The transitive reduction of G = (V,E) is a graph G- = (V,E-) such that + for all v,w in V there is an edge (v,w) in E- if and only if (v,w) is + in E and there is no path from v to w in G with length greater than 1. + + needs to be run after _sort_processes + References + ---------- + https://en.wikipedia.org/wiki/Transitive_reduction + adapted from networkx: https://networkx.org/documentation/stable/_modules/networkx/algorithms/dag.html#transitive_reduction + + """ + + for p_name in self._dep_processes: + p_nbrs = set(self._dep_processes[p_name]) + for dep_p in self._dep_processes[p_name]: + p_nbrs -= self._deps_dict[dep_p] + self._dep_processes[p_name] = list(p_nbrs) + def _sort_processes(self): """Sort processes based on their dependencies (return a list of sorted process names). + new in 0.6.0: now also returns a dictionary of {'p_name':{des,cen,dants}} + for strict checking and transitive reduction. + Stack-based depth-first search traversal. This is based on Tarjan's method for topological sorting. @@ -455,6 +602,7 @@ def _sort_processes(self): """ ordered = [] + self._deps_dict = {p: set() for p in self._dep_processes} # Nodes whose descendents have been completely explored. # These nodes are guaranteed to not be part of a cycle. @@ -484,18 +632,19 @@ def _sort_processes(self): # Add direct descendants of cur to nodes stack next_nodes = [] for nxt in self._dep_processes[cur]: - if nxt not in completed: - if nxt in seen: - # Cycle detected! - cycle = [nxt] - while nodes[-1] != nxt: - cycle.append(nodes.pop()) + if nxt in seen: + # Cycle detected! + cycle = [nxt] + while nodes[-1] != nxt: cycle.append(nodes.pop()) - cycle.reverse() - cycle = "->".join(cycle) - raise RuntimeError( - f"Cycle detected in process graph: {cycle}" - ) + cycle.append(nodes.pop()) + cycle.reverse() + cycle = "->".join(cycle) + raise RuntimeError(f"Cycle detected in process graph: {cycle}") + if nxt in completed: + self._deps_dict[cur].add(nxt) + self._deps_dict[cur].update(self._deps_dict[nxt]) + else: next_nodes.append(nxt) if next_nodes: @@ -509,7 +658,8 @@ def _sort_processes(self): nodes.pop() return ordered - def get_sorted_processes(self): + def get_sorted_processes(self, strict_check=False, transitive_reduce=False): + self._sorted_processes = OrderedDict( [(p_name, self._processes_obj[p_name]) for p_name in self._sort_processes()] ) @@ -534,13 +684,23 @@ class Model(AttrMapping): active = [] - def __init__(self, processes): + def __init__( + self, + processes, + custom_dependencies={}, + strict_check=False, + transitive_reduce=False, + ): """ Parameters ---------- processes : dict Dictionnary with process names as keys and classes (decorated with :func:`process`) as values. + custom_dependencies : dict + Dictionary with dependencies of processes wher this is not clear from + the model, in the case of intent='inout' variables. the dictionary should be in the form: + {('process_name','variable_name'):'dependent_process_name'} or {'p_name__var_name':'dep_p_name'} Raises ------ @@ -572,9 +732,28 @@ def __init__(self, processes): self._processes_to_validate = builder.get_processes_to_validate() - self._dep_processes = builder.get_process_dependencies() + # clean custom dependencies + self._custom_dependencies = {} + for p_name, c_deps in custom_dependencies.items(): + c_deps = ( + {c_deps} if isinstance(c_deps, str) else {c_dep for c_dep in c_deps} + ) + self._custom_dependencies[p_name] = c_deps + + self._dep_processes = builder.get_process_dependencies( + self._custom_dependencies + ) + self._processes = builder.get_sorted_processes() + # these depend on the deps_dict created in sort_processes: + self._strict_check = strict_check + self._transitive_reduce = transitive_reduce + if self._strict_check: + builder._check_inout_vars() + if self._transitive_reduce: + builder.transitive_reduction() + super(Model, self).__init__(self._processes) self._initialized = True @@ -654,9 +833,13 @@ def dependent_processes(self): return self._dep_processes def visualize( - self, show_only_variable=None, show_inputs=False, show_variables=False + self, + show_only_variable=None, + show_inputs=False, + show_variables=False, + show_inout_arrows=True, ): - """Render the model as a graph using dot (require graphviz). + """Render the model as a graph using dot (requires graphviz). Parameters ---------- @@ -683,6 +866,7 @@ def visualize( show_only_variable=show_only_variable, show_inputs=show_inputs, show_variables=show_variables, + show_inout_arrows=show_inout_arrows, ) @property @@ -1035,11 +1219,16 @@ def clone(self): Returns ------- cloned : Model - New Model instance with the same processes. + New Model instance with the same processes. and defined dependencies """ processes_cls = {k: type(obj) for k, obj in self._processes.items()} - return type(self)(processes_cls) + return type(self)( + processes_cls, + self._custom_dependencies, + self._strict_check, + self._transitive_reduce, + ) def update_processes(self, processes): """Add or replace processe(s) in this model. @@ -1058,14 +1247,19 @@ def update_processes(self, processes): """ processes_cls = {k: type(obj) for k, obj in self._processes.items()} processes_cls.update(processes) - return type(self)(processes_cls) + return type(self)( + processes_cls, + self._custom_dependencies, + self._strict_check, + self._transitive_reduce, + ) def drop_processes(self, keys): """Drop processe(s) from this model. Parameters ---------- - keys : str or list of str + keys : str or iterable of str Name(s) of the processes to drop. Returns @@ -1074,13 +1268,56 @@ def drop_processes(self, keys): New Model instance with dropped processes. """ - if isinstance(keys, str): - keys = [keys] + keys = {keys} if isinstance(keys, str) else {key for key in keys} processes_cls = { k: type(obj) for k, obj in self._processes.items() if k not in keys } - return type(self)(processes_cls) + + # nooo we also should check for chains of deps e.g. + # a->b->c->d->e where {b,c,d} are removed + # wake me up when the depndencies end... + # here comes the stack again... defining who we are... + # start a DFS only on these keys again... + # actually it is only dfs on custom deps, so not too bad + # let's see if we can do it in-place + completed = set() + for key in self._custom_dependencies: + if key in completed: + continue + key_stack = [key] + while key_stack: + cur = key_stack[-1] + if cur in completed: + key_stack.pop() + continue + + child_keys = keys.intersection(self._custom_dependencies[cur]) + if child_keys.issubset(completed): + # all children are added, so we are safe + self._custom_dependencies[cur].update( + *[ + self._custom_dependencies[child_key] + for child_key in child_keys + ] + ) + self._custom_dependencies[cur] -= child_keys + completed.add(cur) + key_stack.pop() + else: # if child_keys - completed: + # we need to search deeper: add to the stack. + key_stack.extend([k for k in child_keys - completed]) + + # that was actually quite ok.. now also remove keys from custom deps + for key in keys: + del self._custom_dependencies[key] + + return type(self)( + processes_cls, + self._custom_dependencies, + self._strict_check, + self._transitive_reduce, + ) def __eq__(self, other): if not isinstance(other, self.__class__): diff --git a/xsimlab/tests/test_dot.py b/xsimlab/tests/test_dot.py index ec9fb17..947ec9c 100644 --- a/xsimlab/tests/test_dot.py +++ b/xsimlab/tests/test_dot.py @@ -55,7 +55,7 @@ def _ensure_not_exists(filename): def test_to_graphviz(model): - g = to_graphviz(model) + g = to_graphviz(model, show_inout_arrows=False) actual_nodes = _get_graph_nodes(g) actual_edges = _get_graph_edges(g) expected_nodes = list(model) diff --git a/xsimlab/tests/test_variable.py b/xsimlab/tests/test_variable.py index 159edf5..9efb1c9 100644 --- a/xsimlab/tests/test_variable.py +++ b/xsimlab/tests/test_variable.py @@ -104,9 +104,6 @@ class Foo: def test_foreign(): - with pytest.raises(ValueError, match="intent='inout' is not supported.*"): - xs.foreign(ExampleProcess, "some_var", intent="inout") - var = attr.fields(ExampleProcess).out_foreign_var ref_var = attr.fields(AnotherProcess).another_var diff --git a/xsimlab/utils.py b/xsimlab/utils.py index bbf2ddd..a74a7c6 100644 --- a/xsimlab/utils.py +++ b/xsimlab/utils.py @@ -43,6 +43,33 @@ def __repr__(self): """ +def as_variable_key(key): + """Returns ``key`` as a tuple of the form + ``('process_name', 'var_name')``. + + If ``key`` is given as a string, then process name and variable + name must be separated unambiguously by '__' (double underscore) + and must not be empty. + + """ + key_tuple = None + + if isinstance(key, tuple) and len(key) == 2: + key_tuple = key + + elif isinstance(key, str): + key_split = key.split("__") + if len(key_split) == 2: + p_name, var_name = key_split + if p_name and var_name: + key_tuple = (p_name, var_name) + + if key_tuple is None: + raise ValueError(f"{key!r} is not a valid input variable key") + + return key_tuple + + def variables_dict(process_cls): """Get all xsimlab variables declared in a process. diff --git a/xsimlab/variable.py b/xsimlab/variable.py index d6b642a..1aedb97 100644 --- a/xsimlab/variable.py +++ b/xsimlab/variable.py @@ -445,8 +445,6 @@ def foreign(other_process_cls, var_name, intent="in"): model. """ - if intent == "inout": - raise ValueError("intent='inout' is not supported for foreign variables") ref_var = attr.fields_dict(other_process_cls)[var_name] diff --git a/xsimlab/xr_accessor.py b/xsimlab/xr_accessor.py index a12343d..8fb9d3f 100644 --- a/xsimlab/xr_accessor.py +++ b/xsimlab/xr_accessor.py @@ -11,7 +11,7 @@ from .drivers import XarraySimulationDriver from .model import get_model_variables, Model -from .utils import Frozen, variables_dict +from .utils import Frozen, variables_dict, as_variable_key from .variable import VarType @@ -44,33 +44,6 @@ def _maybe_get_model_from_context(model): return model -def as_variable_key(key): - """Returns ``key`` as a tuple of the form - ``('process_name', 'var_name')``. - - If ``key`` is given as a string, then process name and variable - name must be separated unambiguously by '__' (double underscore) - and must not be empty. - - """ - key_tuple = None - - if isinstance(key, tuple) and len(key) == 2: - key_tuple = key - - elif isinstance(key, str): - key_split = key.split("__") - if len(key_split) == 2: - p_name, var_name = key_split - if p_name and var_name: - key_tuple = (p_name, var_name) - - if key_tuple is None: - raise ValueError(f"{key!r} is not a valid input variable key") - - return key_tuple - - def _flatten_inputs(input_vars): """Returns ``input_vars`` as a flat dictionary where keys are tuples in the form ``(process_name, var_name)``. Raises an error if the