L2.constant_folding
1from functools import partial 2 3from .syntax import ( 4 Abstract, 5 Allocate, 6 Apply, 7 Begin, 8 Branch, 9 Immediate, 10 Let, 11 Load, 12 Primitive, 13 Reference, 14 Store, 15 Term, 16) 17from .util import ( 18 Context, 19 extend_context_with_bindings, 20 recur_terms, 21) 22 23 24def _normalize_commutative_immediate_left( 25 operator: str, 26 left: Term, 27 right: Term, 28) -> Primitive: 29 return Primitive( 30 operator=operator, # type: ignore[arg-type] 31 left=right, 32 right=left, 33 ) 34 35 36def constant_folding_term( 37 term: Term, 38 context: Context, 39) -> Term: 40 recur = partial(constant_folding_term, context=context) # noqa: F841 41 42 match term: 43 case Let(bindings=bindings, body=body): 44 new_bindings, new_context = extend_context_with_bindings(bindings, context, recur) 45 return Let( 46 bindings=new_bindings, 47 body=constant_folding_term(body, new_context), 48 ) 49 50 case Reference(name=name): 51 if name in context: 52 return context[name] 53 return term 54 55 case Abstract(parameters=parameters, body=body): 56 return Abstract(parameters=parameters, body=recur(body)) 57 58 case Apply(target=target, arguments=arguments): 59 return Apply( 60 target=recur(target), 61 arguments=recur_terms(arguments, recur), 62 ) 63 64 case Immediate(value=_value): 65 return term 66 67 case Primitive(operator=operator, left=left, right=right): 68 match operator: 69 case "+": 70 match recur(left), recur(right): 71 case Immediate(value=i1), Immediate(value=i2): 72 return Immediate(value=i1 + i2) 73 74 case Immediate(value=0), right: 75 return right 76 77 case [ 78 Primitive(operator="+", left=Immediate(value=i1), right=left), 79 Primitive(operator="+", left=Immediate(value=i2), right=right), 80 ]: 81 return Primitive( 82 operator="+", 83 left=Immediate(value=i1 + i2), 84 right=Primitive( 85 operator="+", 86 left=left, 87 right=right, 88 ), 89 ) 90 91 case left, Immediate() as right: 92 return _normalize_commutative_immediate_left("+", left, right) 93 94 # Coverage reports a synthetic exit arc on this fallback match arm. 95 # The arm is intentionally reachable and returns the non-folded primitive. 96 case left, right: # pragma: no branch 97 return Primitive( 98 operator="+", 99 left=left, 100 right=right, 101 ) 102 103 case "-": 104 match recur(left), recur(right): 105 case Immediate(value=i1), Immediate(value=i2): 106 return Immediate(value=i1 - i2) 107 108 # Coverage reports a synthetic exit arc on this fallback match arm. 109 # The arm is intentionally reachable and returns the non-folded primitive. 110 case left, right: # pragma: no branch 111 return Primitive(operator="-", left=left, right=right) 112 113 # Coverage may report an extra arc on this literal case label under pattern matching. 114 # Runtime terms validated by the syntax model still follow normal folding logic below. 115 case "*": # pragma: no branch 116 match recur(left), recur(right): 117 case Immediate(value=i1), Immediate(value=i2): 118 return Immediate(value=i1 * i2) 119 120 case Immediate(value=0), _: 121 return Immediate(value=0) 122 123 case _, Immediate(value=0): 124 return Immediate(value=0) 125 126 case Immediate(value=1), right: 127 return right 128 129 case left, Immediate(value=1): 130 return left 131 132 case left, Immediate() as right: 133 return _normalize_commutative_immediate_left("*", left, right) 134 135 # Coverage reports a synthetic exit arc on this fallback match arm. 136 # The arm is intentionally reachable and returns the non-folded primitive. 137 case left, right: # pragma: no branch 138 return Primitive(operator="*", left=left, right=right) 139 140 case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 141 folded_left = recur(left) 142 folded_right = recur(right) 143 folded_consequent = recur(consequent) 144 folded_otherwise = recur(otherwise) 145 match operator: 146 case "<": 147 match folded_left, folded_right: 148 case Immediate(value=i1), Immediate(value=i2): 149 return folded_consequent if i1 < i2 else folded_otherwise 150 case _: 151 pass 152 # Coverage may report an extra arc on this literal case label under pattern matching. 153 # Runtime terms validated by the syntax model use only "<" and "==". 154 case "==": # pragma: no branch 155 match folded_left, folded_right: 156 case Immediate(value=i1), Immediate(value=i2): 157 return folded_consequent if i1 == i2 else folded_otherwise 158 case _: 159 pass 160 return Branch( 161 operator=operator, 162 left=folded_left, 163 right=folded_right, 164 consequent=folded_consequent, 165 otherwise=folded_otherwise, 166 ) 167 168 case Allocate(count=count): 169 return Allocate(count=count) 170 171 case Load(base=base, index=index): 172 return Load(base=recur(base), index=index) 173 174 case Store(base=base, index=index, value=value): 175 return Store(base=recur(base), index=index, value=recur(value)) 176 177 # Coverage may report an extra structural arc for this match arm. 178 # Semantically this always returns the reconstructed Begin node. 179 case Begin(effects=effects, value=value): # pragma: no branch 180 return Begin(effects=recur_terms(effects, recur), value=recur(value))
def
constant_folding_term(term: Term, context: Context) -> Term:
37def constant_folding_term( 38 term: Term, 39 context: Context, 40) -> Term: 41 recur = partial(constant_folding_term, context=context) # noqa: F841 42 43 match term: 44 case Let(bindings=bindings, body=body): 45 new_bindings, new_context = extend_context_with_bindings(bindings, context, recur) 46 return Let( 47 bindings=new_bindings, 48 body=constant_folding_term(body, new_context), 49 ) 50 51 case Reference(name=name): 52 if name in context: 53 return context[name] 54 return term 55 56 case Abstract(parameters=parameters, body=body): 57 return Abstract(parameters=parameters, body=recur(body)) 58 59 case Apply(target=target, arguments=arguments): 60 return Apply( 61 target=recur(target), 62 arguments=recur_terms(arguments, recur), 63 ) 64 65 case Immediate(value=_value): 66 return term 67 68 case Primitive(operator=operator, left=left, right=right): 69 match operator: 70 case "+": 71 match recur(left), recur(right): 72 case Immediate(value=i1), Immediate(value=i2): 73 return Immediate(value=i1 + i2) 74 75 case Immediate(value=0), right: 76 return right 77 78 case [ 79 Primitive(operator="+", left=Immediate(value=i1), right=left), 80 Primitive(operator="+", left=Immediate(value=i2), right=right), 81 ]: 82 return Primitive( 83 operator="+", 84 left=Immediate(value=i1 + i2), 85 right=Primitive( 86 operator="+", 87 left=left, 88 right=right, 89 ), 90 ) 91 92 case left, Immediate() as right: 93 return _normalize_commutative_immediate_left("+", left, right) 94 95 # Coverage reports a synthetic exit arc on this fallback match arm. 96 # The arm is intentionally reachable and returns the non-folded primitive. 97 case left, right: # pragma: no branch 98 return Primitive( 99 operator="+", 100 left=left, 101 right=right, 102 ) 103 104 case "-": 105 match recur(left), recur(right): 106 case Immediate(value=i1), Immediate(value=i2): 107 return Immediate(value=i1 - i2) 108 109 # Coverage reports a synthetic exit arc on this fallback match arm. 110 # The arm is intentionally reachable and returns the non-folded primitive. 111 case left, right: # pragma: no branch 112 return Primitive(operator="-", left=left, right=right) 113 114 # Coverage may report an extra arc on this literal case label under pattern matching. 115 # Runtime terms validated by the syntax model still follow normal folding logic below. 116 case "*": # pragma: no branch 117 match recur(left), recur(right): 118 case Immediate(value=i1), Immediate(value=i2): 119 return Immediate(value=i1 * i2) 120 121 case Immediate(value=0), _: 122 return Immediate(value=0) 123 124 case _, Immediate(value=0): 125 return Immediate(value=0) 126 127 case Immediate(value=1), right: 128 return right 129 130 case left, Immediate(value=1): 131 return left 132 133 case left, Immediate() as right: 134 return _normalize_commutative_immediate_left("*", left, right) 135 136 # Coverage reports a synthetic exit arc on this fallback match arm. 137 # The arm is intentionally reachable and returns the non-folded primitive. 138 case left, right: # pragma: no branch 139 return Primitive(operator="*", left=left, right=right) 140 141 case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 142 folded_left = recur(left) 143 folded_right = recur(right) 144 folded_consequent = recur(consequent) 145 folded_otherwise = recur(otherwise) 146 match operator: 147 case "<": 148 match folded_left, folded_right: 149 case Immediate(value=i1), Immediate(value=i2): 150 return folded_consequent if i1 < i2 else folded_otherwise 151 case _: 152 pass 153 # Coverage may report an extra arc on this literal case label under pattern matching. 154 # Runtime terms validated by the syntax model use only "<" and "==". 155 case "==": # pragma: no branch 156 match folded_left, folded_right: 157 case Immediate(value=i1), Immediate(value=i2): 158 return folded_consequent if i1 == i2 else folded_otherwise 159 case _: 160 pass 161 return Branch( 162 operator=operator, 163 left=folded_left, 164 right=folded_right, 165 consequent=folded_consequent, 166 otherwise=folded_otherwise, 167 ) 168 169 case Allocate(count=count): 170 return Allocate(count=count) 171 172 case Load(base=base, index=index): 173 return Load(base=recur(base), index=index) 174 175 case Store(base=base, index=index, value=value): 176 return Store(base=recur(base), index=index, value=recur(value)) 177 178 # Coverage may report an extra structural arc for this match arm. 179 # Semantically this always returns the reconstructed Begin node. 180 case Begin(effects=effects, value=value): # pragma: no branch 181 return Begin(effects=recur_terms(effects, recur), value=recur(value))