L2.constant_propogation

 1from functools import partial
 2
 3from .syntax import (
 4    Abstract,
 5    Allocate,
 6    Apply,
 7    Begin,
 8    Branch,
 9    Immediate,
10    Let,
11    Load,
12    Primitive,
13    Reference,
14    Store,
15    Term,
16)
17from .util import (
18    Context,
19    extend_context_with_bindings,
20    recur_terms,
21)
22
23
24def constant_propogation_term(
25    term: Term,
26    context: Context,
27) -> Term:
28    recur = partial(constant_propogation_term, context=context)
29
30    match term:
31        case Let(bindings=bindings, body=body):
32            new_bindings, new_context = extend_context_with_bindings(bindings, context, recur)
33            return Let(
34                bindings=new_bindings,
35                body=constant_propogation_term(body, new_context),
36            )
37
38        case Reference(name=name):
39            if name in context:
40                return context[name]
41            return term
42
43        case Abstract(parameters=parameters, body=body):
44            abstract_context = {name: value for name, value in context.items() if name not in parameters}
45            return Abstract(
46                parameters=parameters,
47                body=constant_propogation_term(body, abstract_context),
48            )
49
50        case Apply(target=target, arguments=arguments):
51            return Apply(
52                target=recur(target),
53                arguments=recur_terms(arguments, recur),
54            )
55
56        case Immediate(value=_value):
57            return term
58
59        case Primitive(operator=operator, left=left, right=right):
60            return Primitive(
61                operator=operator,
62                left=recur(left),
63                right=recur(right),
64            )
65
66        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
67            return Branch(
68                operator=operator,
69                left=recur(left),
70                right=recur(right),
71                consequent=recur(consequent),
72                otherwise=recur(otherwise),
73            )
74
75        case Allocate(count=count):
76            return Allocate(count=count)
77
78        case Load(base=base, index=index):
79            return Load(base=recur(base), index=index)
80
81        case Store(base=base, index=index, value=value):
82            return Store(base=recur(base), index=index, value=recur(value))
83
84        case Begin(effects=effects, value=value):  # pragma: no branch
85            return Begin(
86                effects=recur_terms(effects, recur),
87                value=recur(value),
88            )
def constant_propogation_term(term: Term, context: Context) -> Term:
25def constant_propogation_term(
26    term: Term,
27    context: Context,
28) -> Term:
29    recur = partial(constant_propogation_term, context=context)
30
31    match term:
32        case Let(bindings=bindings, body=body):
33            new_bindings, new_context = extend_context_with_bindings(bindings, context, recur)
34            return Let(
35                bindings=new_bindings,
36                body=constant_propogation_term(body, new_context),
37            )
38
39        case Reference(name=name):
40            if name in context:
41                return context[name]
42            return term
43
44        case Abstract(parameters=parameters, body=body):
45            abstract_context = {name: value for name, value in context.items() if name not in parameters}
46            return Abstract(
47                parameters=parameters,
48                body=constant_propogation_term(body, abstract_context),
49            )
50
51        case Apply(target=target, arguments=arguments):
52            return Apply(
53                target=recur(target),
54                arguments=recur_terms(arguments, recur),
55            )
56
57        case Immediate(value=_value):
58            return term
59
60        case Primitive(operator=operator, left=left, right=right):
61            return Primitive(
62                operator=operator,
63                left=recur(left),
64                right=recur(right),
65            )
66
67        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
68            return Branch(
69                operator=operator,
70                left=recur(left),
71                right=recur(right),
72                consequent=recur(consequent),
73                otherwise=recur(otherwise),
74            )
75
76        case Allocate(count=count):
77            return Allocate(count=count)
78
79        case Load(base=base, index=index):
80            return Load(base=recur(base), index=index)
81
82        case Store(base=base, index=index, value=value):
83            return Store(base=recur(base), index=index, value=recur(value))
84
85        case Begin(effects=effects, value=value):  # pragma: no branch
86            return Begin(
87                effects=recur_terms(effects, recur),
88                value=recur(value),
89            )