diff --git a/src/bindings/lib_param_bn/symbolic/model_color.rs b/src/bindings/lib_param_bn/symbolic/model_color.rs index de666f7..323282b 100644 --- a/src/bindings/lib_param_bn/symbolic/model_color.rs +++ b/src/bindings/lib_param_bn/symbolic/model_color.rs @@ -6,8 +6,9 @@ use crate::bindings::lib_param_bn::symbolic::symbolic_context::SymbolicContext; use crate::bindings::lib_param_bn::update_function::UpdateFunction; use crate::bindings::lib_param_bn::variable_id::VariableId; use crate::{throw_index_error, throw_type_error, AsNative}; +use biodivine_lib_bdd::boolean_expression::BooleanExpression as RsBooleanExpression; use biodivine_lib_bdd::BddPartialValuation; -use biodivine_lib_param_bn::FnUpdate; +use biodivine_lib_param_bn::{BinaryOp, FnUpdate}; use either::Either; use pyo3::types::{PyDict, PyList, PyTuple}; use pyo3::{pyclass, pymethods, IntoPy, Py, PyAny, PyObject, PyResult, Python}; @@ -177,8 +178,8 @@ impl ColorModel { /// Specifically, there are three supported modes of operation: /// - If `item` is an `UpdateFunction`, the result is a new `UpdateFunction` that only depends /// on network variables and is the interpretation of the original function under this model. - /// - If `item` is a `BooleanNetwork`, the result is a new `BooleanNetwork` with no parameters - /// where each function has been instantiated. + /// - If `item` is a `BooleanNetwork`, the result is a new `BooleanNetwork` where all + /// uninterpreted functions that are retained in this model are instantiated. /// - If `item` identifies an uninterpreted function (by `ParameterId`, `VariableId`, or /// a string name), the method returns an `UpdateFunction` that is an interpretation of the /// uninterpreted function with specified `args` under this model. This is equivalent to @@ -211,7 +212,7 @@ impl ColorModel { return self.instantiate_update_function(py, update_function); } if let Ok(network) = item.extract::>() { - // For a Boolean network, we try to instantiate every update function separated + // For a Boolean network, we try to instantiate every update function separately // and then remove all unused parameters. if args.is_some() { return throw_type_error( @@ -223,6 +224,10 @@ impl ColorModel { let function = if let Some(function) = bn.get_update_function(var) { self.instantiate_fn_update(function)? } else { + if !self.retained_implicit.contains(&var) { + // This variable is not retained, thus we can't instantiate it. + continue; + } let args = bn.regulators(var); let function_bdd = ctx.as_native().mk_implicit_function_is_true(var, &args); let instantiated_bdd = function_bdd.restrict(&self.to_values()); @@ -231,8 +236,10 @@ impl ColorModel { bn.set_update_function(var, Some(function)).unwrap(); } + let expected = (bn.num_parameters() + bn.num_implicit_parameters()) + - (self.retained_implicit.len() + self.retained_explicit.len()); let bn = bn.prune_unused_parameters(); - assert_eq!(bn.num_parameters(), 0); + assert_eq!(bn.num_parameters() + bn.num_implicit_parameters(), expected); return Ok(BooleanNetwork::from(bn).export_to_python(py)?.into_py(py)); } @@ -304,18 +311,52 @@ impl ColorModel { pub fn instantiate_fn_update(&self, fn_update: &FnUpdate) -> PyResult { let ctx = self.ctx.get().as_native(); - let mut missing_support = fn_update.collect_parameters(); - missing_support.retain(|x| !self.retained_explicit.contains(x)); - if !missing_support.is_empty() { - return throw_index_error(format!( - "Function(s) `{:?}` are not available in this projection.", - missing_support - )); + let all_fn_parameters = fn_update.collect_parameters(); + if all_fn_parameters.is_empty() { + // No need to instantiate, there are no parameters here. + return Ok(fn_update.clone()); } - let update_function_bdd = ctx.mk_fn_update_true(fn_update); - let instantiated_bdd = update_function_bdd.restrict(&self.to_values()); - Ok(FnUpdate::build_from_bdd(ctx, &instantiated_bdd)) + // Only keep parameters that are not retained in this model + let mut missing_fn_parameters = all_fn_parameters.clone(); + missing_fn_parameters.retain(|x| !self.retained_explicit.contains(x)); + if !missing_fn_parameters.is_empty() { + // We can't instantiate this function fully, but we at least fill in some blanks. + fn transform( + ctx: &ColorModel, + missing: &[biodivine_lib_param_bn::ParameterId], + fun: &FnUpdate, + ) -> PyResult { + match fun { + FnUpdate::Const(_) | FnUpdate::Var(_) => Ok(fun.clone()), + FnUpdate::Param(id, args) => { + let args = args + .iter() + .map(|it| transform(ctx, missing, it)) + .collect::>>()?; + if missing.contains(id) { + Ok(FnUpdate::mk_param(*id, &args)) + } else { + ctx.instantiate_explicit_parameter(*id, &args) + } + } + FnUpdate::Not(inner) => Ok(FnUpdate::mk_not(transform(ctx, missing, inner)?)), + FnUpdate::Binary(op, a, b) => Ok(FnUpdate::mk_binary( + *op, + transform(ctx, missing, a)?, + transform(ctx, missing, b)?, + )), + } + } + + transform(self, &missing_fn_parameters, fn_update) + } else { + // Everything unknown in this function is covered. We can instantiate it through + // a BDD. This should be slightly more compact for complex functions. + let update_function_bdd = ctx.mk_fn_update_true(fn_update); + let instantiated_bdd = update_function_bdd.restrict(&self.to_values()); + Ok(FnUpdate::build_from_bdd(ctx, &instantiated_bdd)) + } } /// Turn a function into a `BooleanExpression` using anonymous variable names. @@ -379,4 +420,65 @@ impl ColorModel { .collect::>(); Ok(BooleanExpression::mk_disjunction(clauses)) } + + pub fn instantiate_explicit_parameter( + &self, + par: biodivine_lib_param_bn::ParameterId, + args: &[FnUpdate], + ) -> PyResult { + let ctx = self.ctx.get(); + let table = ctx.as_native().get_explicit_function_table(par); + assert_eq!(args.len(), usize::from(table.arity)); + + fn transform(expr: &RsBooleanExpression, args: &[FnUpdate]) -> FnUpdate { + match expr { + RsBooleanExpression::Const(val) => FnUpdate::Const(*val), + RsBooleanExpression::Variable(var) => { + let mut split = var.split('_'); + split.next().unwrap(); + let id = split.next().unwrap(); + assert!(split.next().is_none()); + let id = id.parse::().unwrap(); + args[id].clone() + } + RsBooleanExpression::Not(inner) => FnUpdate::mk_not(transform(inner, args)), + RsBooleanExpression::And(a, b) => { + let a = transform(a, args); + let b = transform(b, args); + FnUpdate::mk_binary(BinaryOp::And, a, b) + } + RsBooleanExpression::Or(a, b) => { + let a = transform(a, args); + let b = transform(b, args); + FnUpdate::mk_binary(BinaryOp::Or, a, b) + } + RsBooleanExpression::Xor(a, b) => { + let a = transform(a, args); + let b = transform(b, args); + FnUpdate::mk_binary(BinaryOp::Xor, a, b) + } + RsBooleanExpression::Imp(a, b) => { + let a = transform(a, args); + let b = transform(b, args); + FnUpdate::mk_binary(BinaryOp::Imp, a, b) + } + RsBooleanExpression::Iff(a, b) => { + let a = transform(a, args); + let b = transform(b, args); + FnUpdate::mk_binary(BinaryOp::Iff, a, b) + } + RsBooleanExpression::Cond(a, b, c) => { + let a = transform(a, args); + let b = transform(b, args); + let c = transform(c, args); + let cond_1 = FnUpdate::mk_binary(BinaryOp::Imp, a.clone(), b); + let cond_2 = FnUpdate::mk_binary(BinaryOp::Imp, FnUpdate::mk_not(a.clone()), c); + FnUpdate::mk_binary(BinaryOp::And, cond_1, cond_2) + } + } + } + + let expr = self.instantiate_expression(Right(par))?; + Ok(transform(expr.as_native(), args)) + } } diff --git a/tests/test_param_bn_module.py b/tests/test_param_bn_module.py index 7a3b6ef..f792256 100644 --- a/tests/test_param_bn_module.py +++ b/tests/test_param_bn_module.py @@ -1009,6 +1009,18 @@ def test_symbolic_iterators(): assert i.to_symbolic().is_singleton() assert i.to_symbolic().is_subset(graph.mk_function_colors("f", i["f"])) + # Instantiation with a subset of networks: + for i in unit_colors.items(retained=["f"]): + i_bn = i.instantiate(bn) + assert i_bn.explicit_parameter_count() == 0 + assert i_bn.implicit_parameter_count() == 2 + + fn_b = bn.get_update_function("b") + assert fn_b is not None + fn_b = i.instantiate(fn_b) + assert str(fn_b) in {"a", "a & c", "a & !c"} + + # This is basically a mix of tests for ColorSet and VertexSet unit_colored_set = graph.mk_unit_colored_vertices()