From 6f0e0ac379afd63cce56c4bac96c9921daa31ce1 Mon Sep 17 00:00:00 2001 From: Jeremy Saklad Date: Mon, 27 Sep 2021 13:09:00 -0500 Subject: [PATCH] Add BoneMarketModel.AddIf method This method allows a series of constraints to be applied all at once using the same enforcement literal, which can substantially improve readability and writability. --- bonemarketsolver/objects/bone_market_model.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/bonemarketsolver/objects/bone_market_model.py b/bonemarketsolver/objects/bone_market_model.py index 30a576a..1b7f524 100644 --- a/bonemarketsolver/objects/bone_market_model.py +++ b/bonemarketsolver/objects/bone_market_model.py @@ -1,6 +1,6 @@ __author__ = "Jeremy Saklad" -from functools import cache, reduce +from functools import cache, partialmethod, reduce, singledispatch from ortools.sat.python import cp_model @@ -22,6 +22,32 @@ Set `upto` to a value that is unlikely to come into play. Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal.""" return self.AddAllowedAssignments((target, var), ((int(base**exp), base) for base in range(upto + 1))) + def AddIf(self, variable, *constraints): + """Add constraints to the model, only enforced if the specified variable is true. + +Each item in `constraints` must be either a BoundedLinearExpression, a Constraint compatible with OnlyEnforceIf, a 0-arity partial method of CpModel returning a valid item, or an iterable containing valid items.""" + + @singledispatch + def Add(constraint): + if constraint_iterator := iter(constraint): + return frozenset((Add(element) for element in constraint_iterator)) + else: + raise TypeError(f"Invalid constraint: {repr(constraint)}") + + @Add.register + def _(constraint: cp_model.Constraint): + return constraint.OnlyEnforceIf(variable) + + @Add.register + def _(constraint: cp_model.BoundedLinearExpression): + return Add(self.Add(constraint)) + + @Add.register + def _(constraint: partialmethod): + return Add(constraint.__get__(self)()) + + return frozenset((Add(constraint) for constraint in constraints)) + def AddMultiplicationEquality(self, target, variables): """Adds `target == variables[0] * .. * variables[n]`.