diff --git a/bonemarketsolver/objects/bone_market_model.py b/bonemarketsolver/objects/bone_market_model.py index dbc5aad..761d025 100644 --- a/bonemarketsolver/objects/bone_market_model.py +++ b/bonemarketsolver/objects/bone_market_model.py @@ -99,18 +99,26 @@ Each parameter is interpreted as a BoundedLinearExpression, and a layer of indir return intermediate @singledispatchmethod - def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple[cp_model.IntVar, cp_model.Constraint]: - """Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals. + def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple: + """Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraints for use with enforcement literals. `equality` must be either a LinearExp or a unary partialmethod that accepts a target integer variable and returns Constraints.""" - intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name) - return (intermediate, self.Add(intermediate == expression)) + # If expression is either an integer variable with the specified bounds or an integer constant, just pass it through + if isinstance(expression, cp_model.IntVar) and (lambda domain : domain == [lb, ub] or domain[0] == domain[1])(cp_model.IntVar.Proto(expression).domain): + return (expression, ()) + else: + intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name) + return (intermediate, self.Add(intermediate == expression)) @NewIntermediateIntVar.register def _(self, expression: partialmethod, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple: intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name) return (intermediate, expression.__get__(self)(intermediate)) + @NewIntermediateIntVar.register + def _(self, expression: Integral, *args, **kwargs) -> tuple: + return (self.NewConstant(expression), ()) + def NewIntVar(self, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> cp_model.IntVar: return super().NewIntVar(lb, ub, name)