diff --git a/bonemarketsolver/objects/bone_market_model.py b/bonemarketsolver/objects/bone_market_model.py index cab48b1..6e24ddc 100644 --- a/bonemarketsolver/objects/bone_market_model.py +++ b/bonemarketsolver/objects/bone_market_model.py @@ -1,6 +1,6 @@ __author__ = "Jeremy Saklad" -from functools import cache, partialmethod, reduce, singledispatch +from functools import cache, partialmethod, reduce, singledispatch, singledispatchmethod from ortools.sat.python import cp_model @@ -126,11 +126,19 @@ Each parameter is interpreted as a BoundedLinearExpression, and a layer of indir self.AddLinearExpressionInDomain(linear_exp, domain.Complement()).OnlyEnforceIf(intermediate.Not()) return intermediate - def NewIntermediateIntVar(self, linear_exp, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX): - """Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals.""" + @singledispatchmethod + def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX): + """Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals. + +`equality` must be either a LinearExp or a unary partialmethod that accepts a target integer variable and returns Constraints.""" intermediate = super().NewIntVar(lb, ub, name) - return (intermediate, self.Add(intermediate == linear_exp)) + return (intermediate, self.Add(intermediate == expression)) + + @NewIntermediateIntVar.register + def _(self, expression: partialmethod, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX): + intermediate = super().NewIntVar(lb, ub, name) + return (intermediate, expression.__get__(self)(intermediate)) def NewIntVar(self, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX): return super().NewIntVar(lb, ub, name)