L2.dead_code_elimination
1from functools import partial 2 3from .syntax import ( 4 Abstract, 5 Allocate, 6 Apply, 7 Begin, 8 Branch, 9 Identifier, 10 Immediate, 11 Let, 12 Load, 13 Primitive, 14 Reference, 15 Store, 16 Term, 17) 18from .util import ( 19 Context, 20 recur_terms, 21) 22 23 24def is_pure(term: Term) -> bool: 25 match term: 26 case Immediate(): 27 return True 28 29 case Reference(): 30 return True 31 32 case Primitive(left=left, right=right): 33 return is_pure(left) and is_pure(right) 34 35 case Abstract(body=body): 36 return is_pure(body) 37 38 case Let(bindings=bindings, body=body): 39 return all(is_pure(value) for _, value in bindings) and is_pure(body) 40 41 case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise): 42 return is_pure(left) and is_pure(right) and is_pure(consequent) and is_pure(otherwise) 43 44 case Load(base=base): 45 return is_pure(base) 46 47 case Begin(effects=effects, value=value): 48 return all(is_pure(effect) for effect in effects) and is_pure(value) 49 50 case Apply(): 51 return False 52 53 case Allocate(): 54 return False 55 56 case Store(): 57 return False 58 59 60def free_vars(term: Term) -> set[Identifier]: 61 match term: 62 case Immediate(): 63 return set() 64 65 case Reference(name=name): 66 return {name} 67 68 case Primitive(left=left, right=right): 69 return free_vars(left) | free_vars(right) 70 71 case Apply(target=target, arguments=arguments): 72 result = free_vars(target) 73 for argument in arguments: 74 result |= free_vars(argument) 75 return result 76 77 case Abstract(parameters=parameters, body=body): 78 return free_vars(body) - set(parameters) 79 80 case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise): 81 return free_vars(left) | free_vars(right) | free_vars(consequent) | free_vars(otherwise) 82 83 case Load(base=base): 84 return free_vars(base) 85 86 case Store(base=base, value=value): 87 return free_vars(base) | free_vars(value) 88 89 case Begin(effects=effects, value=value): 90 result = free_vars(value) 91 for effect in effects: 92 result |= free_vars(effect) 93 return result 94 95 case Allocate(): 96 return set() 97 98 case Let(bindings=bindings, body=body): 99 names = [name for name, _ in bindings] 100 result = free_vars(body) - set(names) 101 for _, value in bindings: 102 result |= free_vars(value) 103 return result 104 105 106def dead_code_elimination_term( 107 term: Term, 108 context: Context, 109) -> Term: 110 recur = partial(dead_code_elimination_term, context=context) 111 112 match term: 113 case Let(bindings=bindings, body=body): 114 new_values = [(name, recur(value)) for name, value in bindings] 115 new_body = recur(body) 116 117 live = free_vars(new_body) 118 kept_reversed: list[tuple[Identifier, Term]] = [] 119 120 for name, value in reversed(new_values): 121 if name in live or not is_pure(value): 122 kept_reversed.append((name, value)) 123 live.discard(name) 124 live |= free_vars(value) 125 126 kept = list(reversed(kept_reversed)) 127 if len(kept) == 0: 128 return new_body 129 130 return Let(bindings=kept, body=new_body) 131 132 case Reference(name=_name): 133 return term 134 135 case Abstract(parameters=parameters, body=body): 136 return Abstract(parameters=parameters, body=recur(body)) 137 138 case Apply(target=target, arguments=arguments): 139 return Apply( 140 target=recur(target), 141 arguments=recur_terms(arguments, recur), 142 ) 143 144 case Immediate(value=_value): 145 return term 146 147 case Primitive(operator=operator, left=left, right=right): 148 return Primitive( 149 operator=operator, 150 left=recur(left), 151 right=recur(right), 152 ) 153 154 case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 155 return Branch( 156 operator=operator, 157 left=recur(left), 158 right=recur(right), 159 consequent=recur(consequent), 160 otherwise=recur(otherwise), 161 ) 162 163 case Allocate(count=count): 164 return Allocate(count=count) 165 166 case Load(base=base, index=index): 167 return Load(base=recur(base), index=index) 168 169 case Store(base=base, index=index, value=value): 170 return Store(base=recur(base), index=index, value=recur(value)) 171 172 case Begin(effects=effects, value=value): # pragma: no branch 173 new_effects = [recur(effect) for effect in effects] 174 kept_effects = [effect for effect in new_effects if not is_pure(effect)] 175 new_value = recur(value) 176 177 if len(kept_effects) == 0: 178 return new_value 179 180 return Begin(effects=kept_effects, value=new_value)
def
is_pure(term: Term) -> bool:
25def is_pure(term: Term) -> bool: 26 match term: 27 case Immediate(): 28 return True 29 30 case Reference(): 31 return True 32 33 case Primitive(left=left, right=right): 34 return is_pure(left) and is_pure(right) 35 36 case Abstract(body=body): 37 return is_pure(body) 38 39 case Let(bindings=bindings, body=body): 40 return all(is_pure(value) for _, value in bindings) and is_pure(body) 41 42 case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise): 43 return is_pure(left) and is_pure(right) and is_pure(consequent) and is_pure(otherwise) 44 45 case Load(base=base): 46 return is_pure(base) 47 48 case Begin(effects=effects, value=value): 49 return all(is_pure(effect) for effect in effects) and is_pure(value) 50 51 case Apply(): 52 return False 53 54 case Allocate(): 55 return False 56 57 case Store(): 58 return False
def
free_vars(term: Term) -> set[Identifier]:
61def free_vars(term: Term) -> set[Identifier]: 62 match term: 63 case Immediate(): 64 return set() 65 66 case Reference(name=name): 67 return {name} 68 69 case Primitive(left=left, right=right): 70 return free_vars(left) | free_vars(right) 71 72 case Apply(target=target, arguments=arguments): 73 result = free_vars(target) 74 for argument in arguments: 75 result |= free_vars(argument) 76 return result 77 78 case Abstract(parameters=parameters, body=body): 79 return free_vars(body) - set(parameters) 80 81 case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise): 82 return free_vars(left) | free_vars(right) | free_vars(consequent) | free_vars(otherwise) 83 84 case Load(base=base): 85 return free_vars(base) 86 87 case Store(base=base, value=value): 88 return free_vars(base) | free_vars(value) 89 90 case Begin(effects=effects, value=value): 91 result = free_vars(value) 92 for effect in effects: 93 result |= free_vars(effect) 94 return result 95 96 case Allocate(): 97 return set() 98 99 case Let(bindings=bindings, body=body): 100 names = [name for name, _ in bindings] 101 result = free_vars(body) - set(names) 102 for _, value in bindings: 103 result |= free_vars(value) 104 return result
def
dead_code_elimination_term(term: Term, context: Context) -> Term:
107def dead_code_elimination_term( 108 term: Term, 109 context: Context, 110) -> Term: 111 recur = partial(dead_code_elimination_term, context=context) 112 113 match term: 114 case Let(bindings=bindings, body=body): 115 new_values = [(name, recur(value)) for name, value in bindings] 116 new_body = recur(body) 117 118 live = free_vars(new_body) 119 kept_reversed: list[tuple[Identifier, Term]] = [] 120 121 for name, value in reversed(new_values): 122 if name in live or not is_pure(value): 123 kept_reversed.append((name, value)) 124 live.discard(name) 125 live |= free_vars(value) 126 127 kept = list(reversed(kept_reversed)) 128 if len(kept) == 0: 129 return new_body 130 131 return Let(bindings=kept, body=new_body) 132 133 case Reference(name=_name): 134 return term 135 136 case Abstract(parameters=parameters, body=body): 137 return Abstract(parameters=parameters, body=recur(body)) 138 139 case Apply(target=target, arguments=arguments): 140 return Apply( 141 target=recur(target), 142 arguments=recur_terms(arguments, recur), 143 ) 144 145 case Immediate(value=_value): 146 return term 147 148 case Primitive(operator=operator, left=left, right=right): 149 return Primitive( 150 operator=operator, 151 left=recur(left), 152 right=recur(right), 153 ) 154 155 case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 156 return Branch( 157 operator=operator, 158 left=recur(left), 159 right=recur(right), 160 consequent=recur(consequent), 161 otherwise=recur(otherwise), 162 ) 163 164 case Allocate(count=count): 165 return Allocate(count=count) 166 167 case Load(base=base, index=index): 168 return Load(base=recur(base), index=index) 169 170 case Store(base=base, index=index, value=value): 171 return Store(base=recur(base), index=index, value=recur(value)) 172 173 case Begin(effects=effects, value=value): # pragma: no branch 174 new_effects = [recur(effect) for effect in effects] 175 kept_effects = [effect for effect in new_effects if not is_pure(effect)] 176 new_value = recur(value) 177 178 if len(kept_effects) == 0: 179 return new_value 180 181 return Begin(effects=kept_effects, value=new_value)