L2.dead_code_elimination

  1from functools import partial
  2
  3from .syntax import (
  4    Abstract,
  5    Allocate,
  6    Apply,
  7    Begin,
  8    Branch,
  9    Identifier,
 10    Immediate,
 11    Let,
 12    Load,
 13    Primitive,
 14    Reference,
 15    Store,
 16    Term,
 17)
 18from .util import (
 19    Context,
 20    recur_terms,
 21)
 22
 23
 24def is_pure(term: Term) -> bool:
 25    match term:
 26        case Immediate():
 27            return True
 28
 29        case Reference():
 30            return True
 31
 32        case Primitive(left=left, right=right):
 33            return is_pure(left) and is_pure(right)
 34
 35        case Abstract(body=body):
 36            return is_pure(body)
 37
 38        case Let(bindings=bindings, body=body):
 39            return all(is_pure(value) for _, value in bindings) and is_pure(body)
 40
 41        case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise):
 42            return is_pure(left) and is_pure(right) and is_pure(consequent) and is_pure(otherwise)
 43
 44        case Load(base=base):
 45            return is_pure(base)
 46
 47        case Begin(effects=effects, value=value):
 48            return all(is_pure(effect) for effect in effects) and is_pure(value)
 49
 50        case Apply():
 51            return False
 52
 53        case Allocate():
 54            return False
 55
 56        case Store():
 57            return False
 58
 59
 60def free_vars(term: Term) -> set[Identifier]:
 61    match term:
 62        case Immediate():
 63            return set()
 64
 65        case Reference(name=name):
 66            return {name}
 67
 68        case Primitive(left=left, right=right):
 69            return free_vars(left) | free_vars(right)
 70
 71        case Apply(target=target, arguments=arguments):
 72            result = free_vars(target)
 73            for argument in arguments:
 74                result |= free_vars(argument)
 75            return result
 76
 77        case Abstract(parameters=parameters, body=body):
 78            return free_vars(body) - set(parameters)
 79
 80        case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise):
 81            return free_vars(left) | free_vars(right) | free_vars(consequent) | free_vars(otherwise)
 82
 83        case Load(base=base):
 84            return free_vars(base)
 85
 86        case Store(base=base, value=value):
 87            return free_vars(base) | free_vars(value)
 88
 89        case Begin(effects=effects, value=value):
 90            result = free_vars(value)
 91            for effect in effects:
 92                result |= free_vars(effect)
 93            return result
 94
 95        case Allocate():
 96            return set()
 97
 98        case Let(bindings=bindings, body=body):
 99            names = [name for name, _ in bindings]
100            result = free_vars(body) - set(names)
101            for _, value in bindings:
102                result |= free_vars(value)
103            return result
104
105
106def dead_code_elimination_term(
107    term: Term,
108    context: Context,
109) -> Term:
110    recur = partial(dead_code_elimination_term, context=context)
111
112    match term:
113        case Let(bindings=bindings, body=body):
114            new_values = [(name, recur(value)) for name, value in bindings]
115            new_body = recur(body)
116
117            live = free_vars(new_body)
118            kept_reversed: list[tuple[Identifier, Term]] = []
119
120            for name, value in reversed(new_values):
121                if name in live or not is_pure(value):
122                    kept_reversed.append((name, value))
123                    live.discard(name)
124                    live |= free_vars(value)
125
126            kept = list(reversed(kept_reversed))
127            if len(kept) == 0:
128                return new_body
129
130            return Let(bindings=kept, body=new_body)
131
132        case Reference(name=_name):
133            return term
134
135        case Abstract(parameters=parameters, body=body):
136            return Abstract(parameters=parameters, body=recur(body))
137
138        case Apply(target=target, arguments=arguments):
139            return Apply(
140                target=recur(target),
141                arguments=recur_terms(arguments, recur),
142            )
143
144        case Immediate(value=_value):
145            return term
146
147        case Primitive(operator=operator, left=left, right=right):
148            return Primitive(
149                operator=operator,
150                left=recur(left),
151                right=recur(right),
152            )
153
154        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
155            return Branch(
156                operator=operator,
157                left=recur(left),
158                right=recur(right),
159                consequent=recur(consequent),
160                otherwise=recur(otherwise),
161            )
162
163        case Allocate(count=count):
164            return Allocate(count=count)
165
166        case Load(base=base, index=index):
167            return Load(base=recur(base), index=index)
168
169        case Store(base=base, index=index, value=value):
170            return Store(base=recur(base), index=index, value=recur(value))
171
172        case Begin(effects=effects, value=value):  # pragma: no branch
173            new_effects = [recur(effect) for effect in effects]
174            kept_effects = [effect for effect in new_effects if not is_pure(effect)]
175            new_value = recur(value)
176
177            if len(kept_effects) == 0:
178                return new_value
179
180            return Begin(effects=kept_effects, value=new_value)
def is_pure(term: Term) -> bool:
25def is_pure(term: Term) -> bool:
26    match term:
27        case Immediate():
28            return True
29
30        case Reference():
31            return True
32
33        case Primitive(left=left, right=right):
34            return is_pure(left) and is_pure(right)
35
36        case Abstract(body=body):
37            return is_pure(body)
38
39        case Let(bindings=bindings, body=body):
40            return all(is_pure(value) for _, value in bindings) and is_pure(body)
41
42        case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise):
43            return is_pure(left) and is_pure(right) and is_pure(consequent) and is_pure(otherwise)
44
45        case Load(base=base):
46            return is_pure(base)
47
48        case Begin(effects=effects, value=value):
49            return all(is_pure(effect) for effect in effects) and is_pure(value)
50
51        case Apply():
52            return False
53
54        case Allocate():
55            return False
56
57        case Store():
58            return False
def free_vars(term: Term) -> set[Identifier]:
 61def free_vars(term: Term) -> set[Identifier]:
 62    match term:
 63        case Immediate():
 64            return set()
 65
 66        case Reference(name=name):
 67            return {name}
 68
 69        case Primitive(left=left, right=right):
 70            return free_vars(left) | free_vars(right)
 71
 72        case Apply(target=target, arguments=arguments):
 73            result = free_vars(target)
 74            for argument in arguments:
 75                result |= free_vars(argument)
 76            return result
 77
 78        case Abstract(parameters=parameters, body=body):
 79            return free_vars(body) - set(parameters)
 80
 81        case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise):
 82            return free_vars(left) | free_vars(right) | free_vars(consequent) | free_vars(otherwise)
 83
 84        case Load(base=base):
 85            return free_vars(base)
 86
 87        case Store(base=base, value=value):
 88            return free_vars(base) | free_vars(value)
 89
 90        case Begin(effects=effects, value=value):
 91            result = free_vars(value)
 92            for effect in effects:
 93                result |= free_vars(effect)
 94            return result
 95
 96        case Allocate():
 97            return set()
 98
 99        case Let(bindings=bindings, body=body):
100            names = [name for name, _ in bindings]
101            result = free_vars(body) - set(names)
102            for _, value in bindings:
103                result |= free_vars(value)
104            return result
def dead_code_elimination_term(term: Term, context: Context) -> Term:
107def dead_code_elimination_term(
108    term: Term,
109    context: Context,
110) -> Term:
111    recur = partial(dead_code_elimination_term, context=context)
112
113    match term:
114        case Let(bindings=bindings, body=body):
115            new_values = [(name, recur(value)) for name, value in bindings]
116            new_body = recur(body)
117
118            live = free_vars(new_body)
119            kept_reversed: list[tuple[Identifier, Term]] = []
120
121            for name, value in reversed(new_values):
122                if name in live or not is_pure(value):
123                    kept_reversed.append((name, value))
124                    live.discard(name)
125                    live |= free_vars(value)
126
127            kept = list(reversed(kept_reversed))
128            if len(kept) == 0:
129                return new_body
130
131            return Let(bindings=kept, body=new_body)
132
133        case Reference(name=_name):
134            return term
135
136        case Abstract(parameters=parameters, body=body):
137            return Abstract(parameters=parameters, body=recur(body))
138
139        case Apply(target=target, arguments=arguments):
140            return Apply(
141                target=recur(target),
142                arguments=recur_terms(arguments, recur),
143            )
144
145        case Immediate(value=_value):
146            return term
147
148        case Primitive(operator=operator, left=left, right=right):
149            return Primitive(
150                operator=operator,
151                left=recur(left),
152                right=recur(right),
153            )
154
155        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
156            return Branch(
157                operator=operator,
158                left=recur(left),
159                right=recur(right),
160                consequent=recur(consequent),
161                otherwise=recur(otherwise),
162            )
163
164        case Allocate(count=count):
165            return Allocate(count=count)
166
167        case Load(base=base, index=index):
168            return Load(base=recur(base), index=index)
169
170        case Store(base=base, index=index, value=value):
171            return Store(base=recur(base), index=index, value=recur(value))
172
173        case Begin(effects=effects, value=value):  # pragma: no branch
174            new_effects = [recur(effect) for effect in effects]
175            kept_effects = [effect for effect in new_effects if not is_pure(effect)]
176            new_value = recur(value)
177
178            if len(kept_effects) == 0:
179                return new_value
180
181            return Begin(effects=kept_effects, value=new_value)