Add type hints to helper functions
For the sake of maintainability, it is very important the parameters and results of the helper functions are spelled out explicitly. Union type hints are being left until Python 3.10 is supported.
This commit is contained in:
parent
40246c4815
commit
e211becb30
|
@ -1,20 +1,26 @@
|
|||
__author__ = "Jeremy Saklad"
|
||||
__author__: str = "Jeremy Saklad"
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import cache, partialmethod, reduce, singledispatch, singledispatchmethod
|
||||
from numbers import Integral, Number
|
||||
from typing import Final
|
||||
|
||||
from ortools.sat.python import cp_model
|
||||
|
||||
class BoneMarketModel(cp_model.CpModel):
|
||||
"""A CpModel with additional functions for common constraints and enhanced enforcement literal support."""
|
||||
|
||||
__slots__ = ()
|
||||
__slots__: tuple[()] = ()
|
||||
|
||||
def AddAllowedAssignments(self, variables, tuples_list):
|
||||
intermediate_variables, constraints = zip(*(self.NewIntermediateIntVar(variable, f'{repr((variables, tuples_list))}: {variable}') for variable in variables))
|
||||
def AddAllowedAssignments(self, variables: Iterable[Iterable], tuples_list: Iterable[Iterable]) -> tuple:
|
||||
# Used for variable names
|
||||
invocation: Final[str] = repr((variables, tuples_list))
|
||||
|
||||
intermediate_variables, constraints = zip(*(self.NewIntermediateIntVar(variable, f'{invocation}: {variable}') for variable in variables))
|
||||
super().AddAllowedAssignments(intermediate_variables, tuples_list)
|
||||
return constraints
|
||||
|
||||
def AddApproximateExponentiationEquality(self, target, var, exp, upto):
|
||||
def AddApproximateExponentiationEquality(self, target, var, exp: Number, upto: Integral) -> tuple:
|
||||
"""Add an approximate exponentiation equality using a lookup table.
|
||||
|
||||
Set `upto` to a value that is unlikely to come into play.
|
||||
|
@ -22,35 +28,35 @@ Set `upto` to a value that is unlikely to come into play.
|
|||
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
||||
return self.AddAllowedAssignments((target, var), ((int(base**exp), base) for base in range(upto + 1)))
|
||||
|
||||
def AddDivisionEquality(self, target, num, denom):
|
||||
def AddDivisionEquality(self, target, num, denom) -> tuple:
|
||||
"""Adds `target == num // denom` (integer division rounded towards 0).
|
||||
|
||||
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
||||
intermediate_target, target_constraint = self.NewIntermediateIntVar(target, f'{repr(target)} == {repr(num)} // {repr(denom)}: target')
|
||||
intermediate_num, num_constraint = self.NewIntermediateIntVar(num, f'{repr(target)} == {repr(num)} // {repr(denom)}: num', lb = 0)
|
||||
intermediate_denom, denom_constraint = self.NewIntermediateIntVar(denom, f'{repr(target)} == {repr(num)} // {repr(denom)}: denom', lb = 1)
|
||||
# Used for variable names
|
||||
invocation: Final[str] = f'{repr(target)} == {repr(num)} // {repr(denom)}'
|
||||
|
||||
intermediate_target, target_constraint = self.NewIntermediateIntVar(target, f'{invocation}: target')
|
||||
intermediate_num, num_constraint = self.NewIntermediateIntVar(num, f'{invocation}: num', lb=0)
|
||||
intermediate_denom, denom_constraint = self.NewIntermediateIntVar(denom, f'{invocation}: denom', lb=1)
|
||||
|
||||
super().AddDivisionEquality(intermediate_target, intermediate_num, intermediate_denom)
|
||||
return (target_constraint, num_constraint, denom_constraint)
|
||||
|
||||
def AddIf(self, variable, *constraints):
|
||||
def AddIf(self, variable, *constraints: tuple) -> frozenset:
|
||||
"""Add constraints to the model, only enforced if the specified variable is true.
|
||||
|
||||
Each item in `constraints` must be either a BoundedLinearExpression, a Constraint compatible with OnlyEnforceIf, a 0-arity partial method of CpModel returning a valid item, or an iterable containing valid items."""
|
||||
|
||||
@singledispatch
|
||||
def Add(constraint):
|
||||
if constraint_iterator := iter(constraint):
|
||||
return frozenset((Add(element) for element in constraint_iterator))
|
||||
else:
|
||||
raise TypeError(f"Invalid constraint: {repr(constraint)}")
|
||||
def Add(constraint: Iterable) -> frozenset:
|
||||
return frozenset((Add(element) for element in constraint))
|
||||
|
||||
@Add.register
|
||||
def _(constraint: cp_model.Constraint):
|
||||
def _(constraint: cp_model.Constraint) -> cp_model.Constraint:
|
||||
return constraint.OnlyEnforceIf(variable)
|
||||
|
||||
@Add.register
|
||||
def _(constraint: cp_model.BoundedLinearExpression):
|
||||
def _(constraint: cp_model.BoundedLinearExpression) -> cp_model.Constraint:
|
||||
return Add(self.Add(constraint))
|
||||
|
||||
@Add.register
|
||||
|
@ -59,18 +65,18 @@ Each item in `constraints` must be either a BoundedLinearExpression, a Constrain
|
|||
|
||||
return frozenset((Add(constraint) for constraint in constraints))
|
||||
|
||||
def AddMultiplicationEquality(self, target, variables):
|
||||
def AddMultiplicationEquality(self, target, variables: Iterable) -> tuple:
|
||||
"""Adds `target == variables[0] * .. * variables[n]`.
|
||||
|
||||
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
||||
|
||||
superclass = super()
|
||||
superclass: Final = super()
|
||||
|
||||
def Multiply(end, stack):
|
||||
def Multiply(end, stack: list) -> tuple:
|
||||
intermediate_variable, variable_constraint = self.NewIntermediateIntVar(stack.pop(), f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: last variable')
|
||||
|
||||
partial_target = self.NewIntVar(f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: partial target')
|
||||
recursive_constraints = self.AddMultiplicationEquality(partial_target, stack) if len(stack) > 1 else (self.Add(partial_target == stack.pop()),)
|
||||
partial_target: Final[cp_model.IntVar] = self.NewIntVar(f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: partial target')
|
||||
recursive_constraints: Final[tuple] = self.AddMultiplicationEquality(partial_target, stack) if len(stack) > 1 else (self.Add(partial_target == stack.pop()),)
|
||||
|
||||
intermediate_target, target_constraint = self.NewIntermediateIntVar(end, f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: target')
|
||||
|
||||
|
@ -82,29 +88,29 @@ Each parameter is interpreted as a BoundedLinearExpression, and a layer of indir
|
|||
return Multiply(target, variables.copy() if isinstance(variables, list) else list(variables))
|
||||
|
||||
@cache
|
||||
def BoolExpression(self, bounded_linear_exp):
|
||||
def BoolExpression(self, bounded_linear_exp: cp_model.BoundedLinearExpression) -> cp_model.IntVar:
|
||||
"""Add a fully-reified implication using an intermediate Boolean variable."""
|
||||
|
||||
intermediate = self.NewBoolVar(str(bounded_linear_exp))
|
||||
linear_exp = bounded_linear_exp.Expression()
|
||||
domain = cp_model.Domain(*bounded_linear_exp.Bounds())
|
||||
intermediate: Final[cp_model.IntVar] = self.NewBoolVar(str(bounded_linear_exp))
|
||||
linear_exp: Final[cp_model.LinearExp] = bounded_linear_exp.Expression()
|
||||
domain: Final[cp_model.Domain] = cp_model.Domain(*bounded_linear_exp.Bounds())
|
||||
self.AddLinearExpressionInDomain(linear_exp, domain).OnlyEnforceIf(intermediate)
|
||||
self.AddLinearExpressionInDomain(linear_exp, domain.Complement()).OnlyEnforceIf(intermediate.Not())
|
||||
return intermediate
|
||||
|
||||
@singledispatchmethod
|
||||
def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX):
|
||||
def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple[cp_model.IntVar, cp_model.Constraint]:
|
||||
"""Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals.
|
||||
|
||||
`equality` must be either a LinearExp or a unary partialmethod that accepts a target integer variable and returns Constraints."""
|
||||
|
||||
intermediate = super().NewIntVar(lb, ub, name)
|
||||
intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name)
|
||||
return (intermediate, self.Add(intermediate == expression))
|
||||
|
||||
@NewIntermediateIntVar.register
|
||||
def _(self, expression: partialmethod, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX):
|
||||
intermediate = super().NewIntVar(lb, ub, name)
|
||||
def _(self, expression: partialmethod, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple:
|
||||
intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name)
|
||||
return (intermediate, expression.__get__(self)(intermediate))
|
||||
|
||||
def NewIntVar(self, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX):
|
||||
def NewIntVar(self, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> cp_model.IntVar:
|
||||
return super().NewIntVar(lb, ub, name)
|
||||
|
|
Loading…
Reference in New Issue