Add type hints to helper functions
For the sake of maintainability, it is very important the parameters and results of the helper functions are spelled out explicitly. Union type hints are being left until Python 3.10 is supported.
This commit is contained in:
parent
40246c4815
commit
e211becb30
|
@ -1,20 +1,26 @@
|
||||||
__author__ = "Jeremy Saklad"
|
__author__: str = "Jeremy Saklad"
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
from functools import cache, partialmethod, reduce, singledispatch, singledispatchmethod
|
from functools import cache, partialmethod, reduce, singledispatch, singledispatchmethod
|
||||||
|
from numbers import Integral, Number
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
from ortools.sat.python import cp_model
|
from ortools.sat.python import cp_model
|
||||||
|
|
||||||
class BoneMarketModel(cp_model.CpModel):
|
class BoneMarketModel(cp_model.CpModel):
|
||||||
"""A CpModel with additional functions for common constraints and enhanced enforcement literal support."""
|
"""A CpModel with additional functions for common constraints and enhanced enforcement literal support."""
|
||||||
|
|
||||||
__slots__ = ()
|
__slots__: tuple[()] = ()
|
||||||
|
|
||||||
def AddAllowedAssignments(self, variables, tuples_list):
|
def AddAllowedAssignments(self, variables: Iterable[Iterable], tuples_list: Iterable[Iterable]) -> tuple:
|
||||||
intermediate_variables, constraints = zip(*(self.NewIntermediateIntVar(variable, f'{repr((variables, tuples_list))}: {variable}') for variable in variables))
|
# Used for variable names
|
||||||
|
invocation: Final[str] = repr((variables, tuples_list))
|
||||||
|
|
||||||
|
intermediate_variables, constraints = zip(*(self.NewIntermediateIntVar(variable, f'{invocation}: {variable}') for variable in variables))
|
||||||
super().AddAllowedAssignments(intermediate_variables, tuples_list)
|
super().AddAllowedAssignments(intermediate_variables, tuples_list)
|
||||||
return constraints
|
return constraints
|
||||||
|
|
||||||
def AddApproximateExponentiationEquality(self, target, var, exp, upto):
|
def AddApproximateExponentiationEquality(self, target, var, exp: Number, upto: Integral) -> tuple:
|
||||||
"""Add an approximate exponentiation equality using a lookup table.
|
"""Add an approximate exponentiation equality using a lookup table.
|
||||||
|
|
||||||
Set `upto` to a value that is unlikely to come into play.
|
Set `upto` to a value that is unlikely to come into play.
|
||||||
|
@ -22,35 +28,35 @@ Set `upto` to a value that is unlikely to come into play.
|
||||||
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
||||||
return self.AddAllowedAssignments((target, var), ((int(base**exp), base) for base in range(upto + 1)))
|
return self.AddAllowedAssignments((target, var), ((int(base**exp), base) for base in range(upto + 1)))
|
||||||
|
|
||||||
def AddDivisionEquality(self, target, num, denom):
|
def AddDivisionEquality(self, target, num, denom) -> tuple:
|
||||||
"""Adds `target == num // denom` (integer division rounded towards 0).
|
"""Adds `target == num // denom` (integer division rounded towards 0).
|
||||||
|
|
||||||
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
||||||
intermediate_target, target_constraint = self.NewIntermediateIntVar(target, f'{repr(target)} == {repr(num)} // {repr(denom)}: target')
|
# Used for variable names
|
||||||
intermediate_num, num_constraint = self.NewIntermediateIntVar(num, f'{repr(target)} == {repr(num)} // {repr(denom)}: num', lb = 0)
|
invocation: Final[str] = f'{repr(target)} == {repr(num)} // {repr(denom)}'
|
||||||
intermediate_denom, denom_constraint = self.NewIntermediateIntVar(denom, f'{repr(target)} == {repr(num)} // {repr(denom)}: denom', lb = 1)
|
|
||||||
|
intermediate_target, target_constraint = self.NewIntermediateIntVar(target, f'{invocation}: target')
|
||||||
|
intermediate_num, num_constraint = self.NewIntermediateIntVar(num, f'{invocation}: num', lb=0)
|
||||||
|
intermediate_denom, denom_constraint = self.NewIntermediateIntVar(denom, f'{invocation}: denom', lb=1)
|
||||||
|
|
||||||
super().AddDivisionEquality(intermediate_target, intermediate_num, intermediate_denom)
|
super().AddDivisionEquality(intermediate_target, intermediate_num, intermediate_denom)
|
||||||
return (target_constraint, num_constraint, denom_constraint)
|
return (target_constraint, num_constraint, denom_constraint)
|
||||||
|
|
||||||
def AddIf(self, variable, *constraints):
|
def AddIf(self, variable, *constraints: tuple) -> frozenset:
|
||||||
"""Add constraints to the model, only enforced if the specified variable is true.
|
"""Add constraints to the model, only enforced if the specified variable is true.
|
||||||
|
|
||||||
Each item in `constraints` must be either a BoundedLinearExpression, a Constraint compatible with OnlyEnforceIf, a 0-arity partial method of CpModel returning a valid item, or an iterable containing valid items."""
|
Each item in `constraints` must be either a BoundedLinearExpression, a Constraint compatible with OnlyEnforceIf, a 0-arity partial method of CpModel returning a valid item, or an iterable containing valid items."""
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def Add(constraint):
|
def Add(constraint: Iterable) -> frozenset:
|
||||||
if constraint_iterator := iter(constraint):
|
return frozenset((Add(element) for element in constraint))
|
||||||
return frozenset((Add(element) for element in constraint_iterator))
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Invalid constraint: {repr(constraint)}")
|
|
||||||
|
|
||||||
@Add.register
|
@Add.register
|
||||||
def _(constraint: cp_model.Constraint):
|
def _(constraint: cp_model.Constraint) -> cp_model.Constraint:
|
||||||
return constraint.OnlyEnforceIf(variable)
|
return constraint.OnlyEnforceIf(variable)
|
||||||
|
|
||||||
@Add.register
|
@Add.register
|
||||||
def _(constraint: cp_model.BoundedLinearExpression):
|
def _(constraint: cp_model.BoundedLinearExpression) -> cp_model.Constraint:
|
||||||
return Add(self.Add(constraint))
|
return Add(self.Add(constraint))
|
||||||
|
|
||||||
@Add.register
|
@Add.register
|
||||||
|
@ -59,18 +65,18 @@ Each item in `constraints` must be either a BoundedLinearExpression, a Constrain
|
||||||
|
|
||||||
return frozenset((Add(constraint) for constraint in constraints))
|
return frozenset((Add(constraint) for constraint in constraints))
|
||||||
|
|
||||||
def AddMultiplicationEquality(self, target, variables):
|
def AddMultiplicationEquality(self, target, variables: Iterable) -> tuple:
|
||||||
"""Adds `target == variables[0] * .. * variables[n]`.
|
"""Adds `target == variables[0] * .. * variables[n]`.
|
||||||
|
|
||||||
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
Each parameter is interpreted as a BoundedLinearExpression, and a layer of indirection is applied such that each Constraint in the returned tuple can accept an enforcement literal."""
|
||||||
|
|
||||||
superclass = super()
|
superclass: Final = super()
|
||||||
|
|
||||||
def Multiply(end, stack):
|
def Multiply(end, stack: list) -> tuple:
|
||||||
intermediate_variable, variable_constraint = self.NewIntermediateIntVar(stack.pop(), f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: last variable')
|
intermediate_variable, variable_constraint = self.NewIntermediateIntVar(stack.pop(), f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: last variable')
|
||||||
|
|
||||||
partial_target = self.NewIntVar(f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: partial target')
|
partial_target: Final[cp_model.IntVar] = self.NewIntVar(f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: partial target')
|
||||||
recursive_constraints = self.AddMultiplicationEquality(partial_target, stack) if len(stack) > 1 else (self.Add(partial_target == stack.pop()),)
|
recursive_constraints: Final[tuple] = self.AddMultiplicationEquality(partial_target, stack) if len(stack) > 1 else (self.Add(partial_target == stack.pop()),)
|
||||||
|
|
||||||
intermediate_target, target_constraint = self.NewIntermediateIntVar(end, f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: target')
|
intermediate_target, target_constraint = self.NewIntermediateIntVar(end, f'{repr(end)} == {"*".join((repr(variable) for variable in stack))}: target')
|
||||||
|
|
||||||
|
@ -82,29 +88,29 @@ Each parameter is interpreted as a BoundedLinearExpression, and a layer of indir
|
||||||
return Multiply(target, variables.copy() if isinstance(variables, list) else list(variables))
|
return Multiply(target, variables.copy() if isinstance(variables, list) else list(variables))
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def BoolExpression(self, bounded_linear_exp):
|
def BoolExpression(self, bounded_linear_exp: cp_model.BoundedLinearExpression) -> cp_model.IntVar:
|
||||||
"""Add a fully-reified implication using an intermediate Boolean variable."""
|
"""Add a fully-reified implication using an intermediate Boolean variable."""
|
||||||
|
|
||||||
intermediate = self.NewBoolVar(str(bounded_linear_exp))
|
intermediate: Final[cp_model.IntVar] = self.NewBoolVar(str(bounded_linear_exp))
|
||||||
linear_exp = bounded_linear_exp.Expression()
|
linear_exp: Final[cp_model.LinearExp] = bounded_linear_exp.Expression()
|
||||||
domain = cp_model.Domain(*bounded_linear_exp.Bounds())
|
domain: Final[cp_model.Domain] = cp_model.Domain(*bounded_linear_exp.Bounds())
|
||||||
self.AddLinearExpressionInDomain(linear_exp, domain).OnlyEnforceIf(intermediate)
|
self.AddLinearExpressionInDomain(linear_exp, domain).OnlyEnforceIf(intermediate)
|
||||||
self.AddLinearExpressionInDomain(linear_exp, domain.Complement()).OnlyEnforceIf(intermediate.Not())
|
self.AddLinearExpressionInDomain(linear_exp, domain.Complement()).OnlyEnforceIf(intermediate.Not())
|
||||||
return intermediate
|
return intermediate
|
||||||
|
|
||||||
@singledispatchmethod
|
@singledispatchmethod
|
||||||
def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX):
|
def NewIntermediateIntVar(self, expression: cp_model.LinearExpr, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple[cp_model.IntVar, cp_model.Constraint]:
|
||||||
"""Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals.
|
"""Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals.
|
||||||
|
|
||||||
`equality` must be either a LinearExp or a unary partialmethod that accepts a target integer variable and returns Constraints."""
|
`equality` must be either a LinearExp or a unary partialmethod that accepts a target integer variable and returns Constraints."""
|
||||||
|
|
||||||
intermediate = super().NewIntVar(lb, ub, name)
|
intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name)
|
||||||
return (intermediate, self.Add(intermediate == expression))
|
return (intermediate, self.Add(intermediate == expression))
|
||||||
|
|
||||||
@NewIntermediateIntVar.register
|
@NewIntermediateIntVar.register
|
||||||
def _(self, expression: partialmethod, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX):
|
def _(self, expression: partialmethod, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> tuple:
|
||||||
intermediate = super().NewIntVar(lb, ub, name)
|
intermediate: Final[cp_model.IntVar] = super().NewIntVar(lb, ub, name)
|
||||||
return (intermediate, expression.__get__(self)(intermediate))
|
return (intermediate, expression.__get__(self)(intermediate))
|
||||||
|
|
||||||
def NewIntVar(self, name, *, lb = cp_model.INT32_MIN, ub = cp_model.INT32_MAX):
|
def NewIntVar(self, name: str, *, lb: Integral = cp_model.INT32_MIN, ub: Integral = cp_model.INT32_MAX) -> cp_model.IntVar:
|
||||||
return super().NewIntVar(lb, ub, name)
|
return super().NewIntVar(lb, ub, name)
|
||||||
|
|
Loading…
Reference in New Issue