2021-09-22 22:06:44 +00:00
__author__ = " Jeremy Saklad "
2021-09-27 18:09:00 +00:00
from functools import cache , partialmethod , reduce , singledispatch
2021-09-22 22:06:44 +00:00
from ortools . sat . python import cp_model
class BoneMarketModel(cp_model.CpModel):
    """A CpModel with additional functions for common constraints and enhanced enforcement literal support.

    Several constraint builders are overridden so that every operand is first
    bound to a fresh intermediate integer variable through an ordinary linear
    equality (see ``NewIntermediateIntVar``). Those linear equalities are what
    the overridden methods return, which is what lets callers attach
    enforcement literals to them (for example through ``AddIf``).
    """

    __slots__ = ()

    def AddAllowedAssignments(self, variables, tuples_list):
        """Add an allowed-assignments (table) constraint over expressions.

        Each item of ``variables`` is bound to an intermediate variable, and
        the table constraint is posted on those intermediates.

        Returns:
            A tuple holding, for each item of ``variables`` in order, the
            linear constraint that binds it to its intermediate variable.
        """
        # The label is identical for every variable, so build it only once.
        description = repr((variables, tuples_list))
        intermediate_variables, constraints = zip(*(
            self.NewIntermediateIntVar(variable, f' {description} : {variable} ')
            for variable in variables
        ))
        super().AddAllowedAssignments(intermediate_variables, tuples_list)
        return constraints

    def AddApproximateExponentiationEquality(self, target, var, exp, upto):
        """Add an approximate exponentiation equality using a lookup table.

        The table pairs every integer ``base`` in ``range(upto + 1)`` with
        ``int(base ** exp)``. Set `upto` to a value that is unlikely to come
        into play.

        Each parameter is interpreted as a BoundedLinearExpression, and a layer
        of indirection is applied such that each Constraint in the returned
        tuple can accept an enforcement literal.
        """
        return self.AddAllowedAssignments(
            (target, var),
            ((int(base ** exp), base) for base in range(upto + 1)),
        )

    def AddDivisionEquality(self, target, num, denom):
        """Adds `target == num // denom` (integer division rounded towards 0).

        Each parameter is interpreted as a BoundedLinearExpression, and a layer
        of indirection is applied such that each Constraint in the returned
        tuple can accept an enforcement literal.

        Returns:
            A tuple ``(target_constraint, num_constraint, denom_constraint)``.
        """
        # Shared prefix of the three intermediate variable names.
        description = f' {target!r} == {num!r} // {denom!r} '

        intermediate_target, target_constraint = self.NewIntermediateIntVar(target, f'{description}: target ')
        # The numerator and denominator intermediates are created with a lower bound of 0.
        intermediate_num, num_constraint = self.NewIntermediateIntVar(num, f'{description}: num ', lb=0)
        intermediate_denom, denom_constraint = self.NewIntermediateIntVar(denom, f'{description}: denom ', lb=0)

        super().AddDivisionEquality(intermediate_target, intermediate_num, intermediate_denom)
        return (target_constraint, num_constraint, denom_constraint)

    def AddIf(self, variable, *constraints):
        """Add constraints to the model, only enforced if the specified variable is true.

        Each item in `constraints` must be either a BoundedLinearExpression, a
        Constraint compatible with OnlyEnforceIf, a 0-arity partial method of
        CpModel returning a valid item, or an iterable containing valid items.

        Returns:
            A frozenset with one entry per item of ``constraints``; iterable
            items yield a nested frozenset of their own results.

        Raises:
            TypeError: If an item is none of the accepted kinds.
        """

        @singledispatch
        def Add(constraint):
            # Fallback: anything not matched by a registered type must be an
            # iterable of valid items. iter() itself raises TypeError for
            # non-iterables, so translate that into the descriptive message.
            try:
                elements = iter(constraint)
            except TypeError as error:
                raise TypeError(f" Invalid constraint: {constraint!r} ") from error
            return frozenset(Add(element) for element in elements)

        @Add.register(cp_model.Constraint)
        def _(constraint):
            return constraint.OnlyEnforceIf(variable)

        @Add.register(cp_model.BoundedLinearExpression)
        def _(constraint):
            # Post the expression first, then enforce the resulting Constraint.
            return Add(self.Add(constraint))

        @Add.register(partialmethod)
        def _(constraint):
            # Bind the partial method to this model, call it, and process whatever it returns.
            return Add(constraint.__get__(self)())

        return frozenset(Add(constraint) for constraint in constraints)

    def AddMultiplicationEquality(self, target, variables):
        """Adds `target == variables[0] * .. * variables[n]`.

        Each parameter is interpreted as a BoundedLinearExpression, and a layer
        of indirection is applied such that each Constraint in the returned
        tuple can accept an enforcement literal.

        Products of more than two operands are built recursively: the last
        operand is multiplied by a partial target that stands for the product
        of all preceding operands. A single operand degenerates to
        ``target == variables[0]``.

        Returns:
            A tuple of the linear constraints binding the operands and the
            target to their intermediate variables.

        Raises:
            IndexError: If ``variables`` is empty.
        """
        # Work on a copy so the caller's sequence is never mutated.
        operands = list(variables)

        # Build the label before consuming any operand so that every generated
        # name mentions the complete product.
        product = " * ".join(repr(operand) for operand in operands)
        description = f' {target!r} == {product} '

        intermediate_variable, variable_constraint = self.NewIntermediateIntVar(
            operands.pop(), f'{description}: last variable ')

        if not operands:
            # Single operand: there is nothing to multiply, so tie the two
            # intermediates together directly.
            intermediate_target, target_constraint = self.NewIntermediateIntVar(
                target, f'{description}: target ')
            self.Add(intermediate_target == intermediate_variable)
            return (variable_constraint, target_constraint)

        partial_target = self.NewIntVar(f'{description}: partial target ')
        if len(operands) > 1:
            recursive_constraints = self.AddMultiplicationEquality(partial_target, operands)
        else:
            recursive_constraints = (self.Add(partial_target == operands[0]),)

        intermediate_target, target_constraint = self.NewIntermediateIntVar(
            target, f'{description}: target ')
        super().AddMultiplicationEquality(intermediate_target, (partial_target, intermediate_variable))

        return (variable_constraint, *recursive_constraints, target_constraint)

    # NOTE: functools.cache memoises on (self, bounded_linear_exp), so repeated
    # calls with the same expression object reuse one Boolean variable. The
    # cache holds strong references to its arguments for as long as it lives.
    @cache
    def BoolExpression(self, bounded_linear_exp):
        """Add a fully-reified implication using an intermediate Boolean variable.

        The returned Boolean variable, when true, enforces that the expression
        lies within its bounds, and when false, enforces that it lies in the
        complement of those bounds.
        """
        intermediate = self.NewBoolVar(str(bounded_linear_exp))
        linear_exp = bounded_linear_exp.Expression()
        domain = cp_model.Domain(*bounded_linear_exp.Bounds())

        self.AddLinearExpressionInDomain(linear_exp, domain).OnlyEnforceIf(intermediate)
        self.AddLinearExpressionInDomain(linear_exp, domain.Complement()).OnlyEnforceIf(intermediate.Not())
        return intermediate

    def NewIntermediateIntVar(self, linear_exp, name, *, lb=cp_model.INT_MIN // 8, ub=cp_model.INT_MAX // 8):
        """Creates an integer variable equivalent to the given expression and returns a tuple consisting of the variable and constraint for use with enforcement literals."""
        intermediate = super().NewIntVar(lb, ub, name)
        return (intermediate, self.Add(intermediate == linear_exp))

    def NewIntVar(self, name, *, lb=cp_model.INT32_MIN, ub=cp_model.INT32_MAX):
        """Create an integer variable called ``name``; the bounds are keyword-only and default to the 32-bit limits."""
        return super().NewIntVar(lb, ub, name)