L3.eliminate_letrec
1# noqa: F841 2from collections.abc import Mapping 3 4from L2 import syntax as L2 5 6from . import syntax as L3 7 8type Context = Mapping[L3.Identifier, None] 9 10 11def eliminate_letrec_term( 12 term: L3.Term, 13 context: Context, 14) -> L2.Term: 15 match term: 16 case L3.Let(bindings=bindings, body=body): 17 return L2.Let( 18 bindings=[(name, eliminate_letrec_term(value, context)) for name, value in bindings], 19 body=eliminate_letrec_term(body, context), 20 ) 21 22 case L3.LetRec(bindings=bindings, body=body): 23 # Mark all binding names as recursive in the context 24 binding_names = [name for name, _ in bindings] 25 new_context: Context = {**context, **dict.fromkeys(binding_names)} # type: ignore 26 27 # Check which bindings need heap allocation based on their values 28 # Simple values (Immediate, Allocate) can be stored directly 29 # Complex values (everything else) need Allocate + Store 30 simple_binding_indices: set[int] = set() 31 for i, (_, value) in enumerate(bindings): 32 match value: 33 case L3.Immediate() | L3.Allocate(): 34 simple_binding_indices.add(i) 35 case _: 36 pass 37 38 # Separate simple and complex bindings 39 simple_bindings: list[tuple[str, L2.Term]] = [] 40 complex_bindings: list[tuple[str, L3.Term]] = [] 41 complex_binding_names: list[str] = [] 42 43 for i, (name, value) in enumerate(bindings): 44 if i in simple_binding_indices: 45 transformed_value = eliminate_letrec_term(value, new_context) 46 simple_bindings.append((name, transformed_value)) 47 else: 48 complex_bindings.append((name, value)) 49 complex_binding_names.append(name) 50 51 # Create stores for complex bindings 52 stores: list[L2.Term] = [] 53 for name, value in complex_bindings: 54 transformed_value = eliminate_letrec_term(value, new_context) 55 stores.append( 56 L2.Store( 57 base=L2.Reference(name=name), 58 index=0, 59 value=transformed_value, 60 ) 61 ) 62 63 # Transform the body 64 transformed_body = eliminate_letrec_term(body, new_context) 65 66 # Build the result 67 all_bindings = simple_bindings + [(name, L2.Allocate(count=1)) for name in complex_binding_names] 68 69 if stores: 70 return L2.Let( 71 bindings=all_bindings, 72 body=L2.Begin( 73 effects=stores, 74 value=transformed_body, 75 ), 76 ) 77 else: 78 return L2.Let( 79 bindings=all_bindings, 80 body=transformed_body, 81 ) 82 83 case L3.Reference(name=name): 84 # if name is a recursive variable -> (Load (Reference name))) 85 # else (Reference name) 86 if name in context: 87 return L2.Load(base=L2.Reference(name=name), index=0) 88 else: 89 return L2.Reference(name=name) 90 91 case L3.Abstract(parameters=parameters, body=body): 92 return L2.Abstract(parameters=parameters, body=eliminate_letrec_term(body, context)) 93 94 case L3.Apply(target=target, arguments=arguments): 95 return L2.Apply( 96 target=eliminate_letrec_term(target, context), 97 arguments=[eliminate_letrec_term(argument, context) for argument in arguments], 98 ) 99 100 case L3.Immediate(value=value): 101 return L2.Immediate(value=value) 102 103 case L3.Primitive(operator=operator, left=left, right=right): 104 return L2.Primitive( 105 operator=operator, 106 left=eliminate_letrec_term(left, context), 107 right=eliminate_letrec_term(right, context), 108 ) 109 110 case L3.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 111 return L2.Branch( 112 operator=operator, 113 left=eliminate_letrec_term(left, context), 114 right=eliminate_letrec_term(right, context), 115 consequent=eliminate_letrec_term(consequent, context), 116 otherwise=eliminate_letrec_term(otherwise, context), 117 ) 118 119 case L3.Allocate(count=count): 120 return L2.Allocate(count=count) 121 122 case L3.Load(base=base, index=index): 123 return L2.Load( 124 base=eliminate_letrec_term(base, context), 125 index=index, 126 ) 127 128 case L3.Store(base=base, index=index, value=value): 129 return L2.Store( 130 base=eliminate_letrec_term(base, context), 131 index=index, 132 value=eliminate_letrec_term(value, context), 133 ) 134 135 case L3.Begin(effects=effects, value=value): # pragma: no branch 136 return L2.Begin( 137 effects=[eliminate_letrec_term(effect, context) for effect in effects], 138 value=eliminate_letrec_term(value, context), 139 ) 140 141 142def eliminate_letrec_program( 143 program: L3.Program, 144) -> L2.Program: 145 match program: 146 case L3.Program(parameters=parameters, body=body): # pragma: no branch 147 return L2.Program( 148 parameters=parameters, 149 body=eliminate_letrec_term(body, {}), 150 )
type Context =
Mapping[Identifier, None]
def
eliminate_letrec_term(term: Term, context: Context) -> Term:
12def eliminate_letrec_term( 13 term: L3.Term, 14 context: Context, 15) -> L2.Term: 16 match term: 17 case L3.Let(bindings=bindings, body=body): 18 return L2.Let( 19 bindings=[(name, eliminate_letrec_term(value, context)) for name, value in bindings], 20 body=eliminate_letrec_term(body, context), 21 ) 22 23 case L3.LetRec(bindings=bindings, body=body): 24 # Mark all binding names as recursive in the context 25 binding_names = [name for name, _ in bindings] 26 new_context: Context = {**context, **dict.fromkeys(binding_names)} # type: ignore 27 28 # Check which bindings need heap allocation based on their values 29 # Simple values (Immediate, Allocate) can be stored directly 30 # Complex values (everything else) need Allocate + Store 31 simple_binding_indices: set[int] = set() 32 for i, (_, value) in enumerate(bindings): 33 match value: 34 case L3.Immediate() | L3.Allocate(): 35 simple_binding_indices.add(i) 36 case _: 37 pass 38 39 # Separate simple and complex bindings 40 simple_bindings: list[tuple[str, L2.Term]] = [] 41 complex_bindings: list[tuple[str, L3.Term]] = [] 42 complex_binding_names: list[str] = [] 43 44 for i, (name, value) in enumerate(bindings): 45 if i in simple_binding_indices: 46 transformed_value = eliminate_letrec_term(value, new_context) 47 simple_bindings.append((name, transformed_value)) 48 else: 49 complex_bindings.append((name, value)) 50 complex_binding_names.append(name) 51 52 # Create stores for complex bindings 53 stores: list[L2.Term] = [] 54 for name, value in complex_bindings: 55 transformed_value = eliminate_letrec_term(value, new_context) 56 stores.append( 57 L2.Store( 58 base=L2.Reference(name=name), 59 index=0, 60 value=transformed_value, 61 ) 62 ) 63 64 # Transform the body 65 transformed_body = eliminate_letrec_term(body, new_context) 66 67 # Build the result 68 all_bindings = simple_bindings + [(name, L2.Allocate(count=1)) for name in complex_binding_names] 69 70 if stores: 71 return L2.Let( 72 bindings=all_bindings, 73 body=L2.Begin( 74 effects=stores, 75 value=transformed_body, 76 ), 77 ) 78 else: 79 return L2.Let( 80 bindings=all_bindings, 81 body=transformed_body, 82 ) 83 84 case L3.Reference(name=name): 85 # if name is a recursive variable -> (Load (Reference name))) 86 # else (Reference name) 87 if name in context: 88 return L2.Load(base=L2.Reference(name=name), index=0) 89 else: 90 return L2.Reference(name=name) 91 92 case L3.Abstract(parameters=parameters, body=body): 93 return L2.Abstract(parameters=parameters, body=eliminate_letrec_term(body, context)) 94 95 case L3.Apply(target=target, arguments=arguments): 96 return L2.Apply( 97 target=eliminate_letrec_term(target, context), 98 arguments=[eliminate_letrec_term(argument, context) for argument in arguments], 99 ) 100 101 case L3.Immediate(value=value): 102 return L2.Immediate(value=value) 103 104 case L3.Primitive(operator=operator, left=left, right=right): 105 return L2.Primitive( 106 operator=operator, 107 left=eliminate_letrec_term(left, context), 108 right=eliminate_letrec_term(right, context), 109 ) 110 111 case L3.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 112 return L2.Branch( 113 operator=operator, 114 left=eliminate_letrec_term(left, context), 115 right=eliminate_letrec_term(right, context), 116 consequent=eliminate_letrec_term(consequent, context), 117 otherwise=eliminate_letrec_term(otherwise, context), 118 ) 119 120 case L3.Allocate(count=count): 121 return L2.Allocate(count=count) 122 123 case L3.Load(base=base, index=index): 124 return L2.Load( 125 base=eliminate_letrec_term(base, context), 126 index=index, 127 ) 128 129 case L3.Store(base=base, index=index, value=value): 130 return L2.Store( 131 base=eliminate_letrec_term(base, context), 132 index=index, 133 value=eliminate_letrec_term(value, context), 134 ) 135 136 case L3.Begin(effects=effects, value=value): # pragma: no branch 137 return L2.Begin( 138 effects=[eliminate_letrec_term(effect, context) for effect in effects], 139 value=eliminate_letrec_term(value, context), 140 )