L2.constant_folding

  1from functools import partial
  2
  3from .syntax import (
  4    Abstract,
  5    Allocate,
  6    Apply,
  7    Begin,
  8    Branch,
  9    Immediate,
 10    Let,
 11    Load,
 12    Primitive,
 13    Reference,
 14    Store,
 15    Term,
 16)
 17from .util import (
 18    Context,
 19    extend_context_with_bindings,
 20    recur_terms,
 21)
 22
 23
 24def _normalize_commutative_immediate_left(
 25    operator: str,
 26    left: Term,
 27    right: Term,
 28) -> Primitive:
 29    return Primitive(
 30        operator=operator,  # type: ignore[arg-type]
 31        left=right,
 32        right=left,
 33    )
 34
 35
 36def constant_folding_term(
 37    term: Term,
 38    context: Context,
 39) -> Term:
 40    recur = partial(constant_folding_term, context=context)  # noqa: F841
 41
 42    match term:
 43        case Let(bindings=bindings, body=body):
 44            new_bindings, new_context = extend_context_with_bindings(bindings, context, recur)
 45            return Let(
 46                bindings=new_bindings,
 47                body=constant_folding_term(body, new_context),
 48            )
 49
 50        case Reference(name=name):
 51            if name in context:
 52                return context[name]
 53            return term
 54
 55        case Abstract(parameters=parameters, body=body):
 56            return Abstract(parameters=parameters, body=recur(body))
 57
 58        case Apply(target=target, arguments=arguments):
 59            return Apply(
 60                target=recur(target),
 61                arguments=recur_terms(arguments, recur),
 62            )
 63
 64        case Immediate(value=_value):
 65            return term
 66
 67        case Primitive(operator=operator, left=left, right=right):
 68            match operator:
 69                case "+":
 70                    match recur(left), recur(right):
 71                        case Immediate(value=i1), Immediate(value=i2):
 72                            return Immediate(value=i1 + i2)
 73
 74                        case Immediate(value=0), right:
 75                            return right
 76
 77                        case [
 78                            Primitive(operator="+", left=Immediate(value=i1), right=left),
 79                            Primitive(operator="+", left=Immediate(value=i2), right=right),
 80                        ]:
 81                            return Primitive(
 82                                operator="+",
 83                                left=Immediate(value=i1 + i2),
 84                                right=Primitive(
 85                                    operator="+",
 86                                    left=left,
 87                                    right=right,
 88                                ),
 89                            )
 90
 91                        case left, Immediate() as right:
 92                            return _normalize_commutative_immediate_left("+", left, right)
 93
 94                        # Coverage reports a synthetic exit arc on this fallback match arm.
 95                        # The arm is intentionally reachable and returns the non-folded primitive.
 96                        case left, right:  # pragma: no branch
 97                            return Primitive(
 98                                operator="+",
 99                                left=left,
100                                right=right,
101                            )
102
103                case "-":
104                    match recur(left), recur(right):
105                        case Immediate(value=i1), Immediate(value=i2):
106                            return Immediate(value=i1 - i2)
107
108                        # Coverage reports a synthetic exit arc on this fallback match arm.
109                        # The arm is intentionally reachable and returns the non-folded primitive.
110                        case left, right:  # pragma: no branch
111                            return Primitive(operator="-", left=left, right=right)
112
113                # Coverage may report an extra arc on this literal case label under pattern matching.
114                # Runtime terms validated by the syntax model still follow normal folding logic below.
115                case "*":  # pragma: no branch
116                    match recur(left), recur(right):
117                        case Immediate(value=i1), Immediate(value=i2):
118                            return Immediate(value=i1 * i2)
119
120                        case Immediate(value=0), _:
121                            return Immediate(value=0)
122
123                        case _, Immediate(value=0):
124                            return Immediate(value=0)
125
126                        case Immediate(value=1), right:
127                            return right
128
129                        case left, Immediate(value=1):
130                            return left
131
132                        case left, Immediate() as right:
133                            return _normalize_commutative_immediate_left("*", left, right)
134
135                        # Coverage reports a synthetic exit arc on this fallback match arm.
136                        # The arm is intentionally reachable and returns the non-folded primitive.
137                        case left, right:  # pragma: no branch
138                            return Primitive(operator="*", left=left, right=right)
139
140        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
141            folded_left = recur(left)
142            folded_right = recur(right)
143            folded_consequent = recur(consequent)
144            folded_otherwise = recur(otherwise)
145            match operator:
146                case "<":
147                    match folded_left, folded_right:
148                        case Immediate(value=i1), Immediate(value=i2):
149                            return folded_consequent if i1 < i2 else folded_otherwise
150                        case _:
151                            pass
152                # Coverage may report an extra arc on this literal case label under pattern matching.
153                # Runtime terms validated by the syntax model use only "<" and "==".
154                case "==":  # pragma: no branch
155                    match folded_left, folded_right:
156                        case Immediate(value=i1), Immediate(value=i2):
157                            return folded_consequent if i1 == i2 else folded_otherwise
158                        case _:
159                            pass
160            return Branch(
161                operator=operator,
162                left=folded_left,
163                right=folded_right,
164                consequent=folded_consequent,
165                otherwise=folded_otherwise,
166            )
167
168        case Allocate(count=count):
169            return Allocate(count=count)
170
171        case Load(base=base, index=index):
172            return Load(base=recur(base), index=index)
173
174        case Store(base=base, index=index, value=value):
175            return Store(base=recur(base), index=index, value=recur(value))
176
177        # Coverage may report an extra structural arc for this match arm.
178        # Semantically this always returns the reconstructed Begin node.
179        case Begin(effects=effects, value=value):  # pragma: no branch
180            return Begin(effects=recur_terms(effects, recur), value=recur(value))
def constant_folding_term(term: Term, context: Context) -> Term:
 37def constant_folding_term(
 38    term: Term,
 39    context: Context,
 40) -> Term:
 41    recur = partial(constant_folding_term, context=context)  # noqa: F841
 42
 43    match term:
 44        case Let(bindings=bindings, body=body):
 45            new_bindings, new_context = extend_context_with_bindings(bindings, context, recur)
 46            return Let(
 47                bindings=new_bindings,
 48                body=constant_folding_term(body, new_context),
 49            )
 50
 51        case Reference(name=name):
 52            if name in context:
 53                return context[name]
 54            return term
 55
 56        case Abstract(parameters=parameters, body=body):
 57            return Abstract(parameters=parameters, body=recur(body))
 58
 59        case Apply(target=target, arguments=arguments):
 60            return Apply(
 61                target=recur(target),
 62                arguments=recur_terms(arguments, recur),
 63            )
 64
 65        case Immediate(value=_value):
 66            return term
 67
 68        case Primitive(operator=operator, left=left, right=right):
 69            match operator:
 70                case "+":
 71                    match recur(left), recur(right):
 72                        case Immediate(value=i1), Immediate(value=i2):
 73                            return Immediate(value=i1 + i2)
 74
 75                        case Immediate(value=0), right:
 76                            return right
 77
 78                        case [
 79                            Primitive(operator="+", left=Immediate(value=i1), right=left),
 80                            Primitive(operator="+", left=Immediate(value=i2), right=right),
 81                        ]:
 82                            return Primitive(
 83                                operator="+",
 84                                left=Immediate(value=i1 + i2),
 85                                right=Primitive(
 86                                    operator="+",
 87                                    left=left,
 88                                    right=right,
 89                                ),
 90                            )
 91
 92                        case left, Immediate() as right:
 93                            return _normalize_commutative_immediate_left("+", left, right)
 94
 95                        # Coverage reports a synthetic exit arc on this fallback match arm.
 96                        # The arm is intentionally reachable and returns the non-folded primitive.
 97                        case left, right:  # pragma: no branch
 98                            return Primitive(
 99                                operator="+",
100                                left=left,
101                                right=right,
102                            )
103
104                case "-":
105                    match recur(left), recur(right):
106                        case Immediate(value=i1), Immediate(value=i2):
107                            return Immediate(value=i1 - i2)
108
109                        # Coverage reports a synthetic exit arc on this fallback match arm.
110                        # The arm is intentionally reachable and returns the non-folded primitive.
111                        case left, right:  # pragma: no branch
112                            return Primitive(operator="-", left=left, right=right)
113
114                # Coverage may report an extra arc on this literal case label under pattern matching.
115                # Runtime terms validated by the syntax model still follow normal folding logic below.
116                case "*":  # pragma: no branch
117                    match recur(left), recur(right):
118                        case Immediate(value=i1), Immediate(value=i2):
119                            return Immediate(value=i1 * i2)
120
121                        case Immediate(value=0), _:
122                            return Immediate(value=0)
123
124                        case _, Immediate(value=0):
125                            return Immediate(value=0)
126
127                        case Immediate(value=1), right:
128                            return right
129
130                        case left, Immediate(value=1):
131                            return left
132
133                        case left, Immediate() as right:
134                            return _normalize_commutative_immediate_left("*", left, right)
135
136                        # Coverage reports a synthetic exit arc on this fallback match arm.
137                        # The arm is intentionally reachable and returns the non-folded primitive.
138                        case left, right:  # pragma: no branch
139                            return Primitive(operator="*", left=left, right=right)
140
141        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
142            folded_left = recur(left)
143            folded_right = recur(right)
144            folded_consequent = recur(consequent)
145            folded_otherwise = recur(otherwise)
146            match operator:
147                case "<":
148                    match folded_left, folded_right:
149                        case Immediate(value=i1), Immediate(value=i2):
150                            return folded_consequent if i1 < i2 else folded_otherwise
151                        case _:
152                            pass
153                # Coverage may report an extra arc on this literal case label under pattern matching.
154                # Runtime terms validated by the syntax model use only "<" and "==".
155                case "==":  # pragma: no branch
156                    match folded_left, folded_right:
157                        case Immediate(value=i1), Immediate(value=i2):
158                            return folded_consequent if i1 == i2 else folded_otherwise
159                        case _:
160                            pass
161            return Branch(
162                operator=operator,
163                left=folded_left,
164                right=folded_right,
165                consequent=folded_consequent,
166                otherwise=folded_otherwise,
167            )
168
169        case Allocate(count=count):
170            return Allocate(count=count)
171
172        case Load(base=base, index=index):
173            return Load(base=recur(base), index=index)
174
175        case Store(base=base, index=index, value=value):
176            return Store(base=recur(base), index=index, value=recur(value))
177
178        # Coverage may report an extra structural arc for this match arm.
179        # Semantically this always returns the reconstructed Begin node.
180        case Begin(effects=effects, value=value):  # pragma: no branch
181            return Begin(effects=recur_terms(effects, recur), value=recur(value))