L3.eliminate_letrec

  1# noqa: F841
  2from collections.abc import Mapping
  3
  4from L2 import syntax as L2
  5
  6from . import syntax as L3
  7
  8type Context = Mapping[L3.Identifier, None]
  9
 10
 11def eliminate_letrec_term(
 12    term: L3.Term,
 13    context: Context,
 14) -> L2.Term:
 15    match term:
 16        case L3.Let(bindings=bindings, body=body):
 17            return L2.Let(
 18                bindings=[(name, eliminate_letrec_term(value, context)) for name, value in bindings],
 19                body=eliminate_letrec_term(body, context),
 20            )
 21
 22        case L3.LetRec(bindings=bindings, body=body):
 23            # Mark all binding names as recursive in the context
 24            binding_names = [name for name, _ in bindings]
 25            new_context: Context = {**context, **dict.fromkeys(binding_names)}  # type: ignore
 26
 27            # Check which bindings need heap allocation based on their values
 28            # Simple values (Immediate, Allocate) can be stored directly
 29            # Complex values (everything else) need Allocate + Store
 30            simple_binding_indices: set[int] = set()
 31            for i, (_, value) in enumerate(bindings):
 32                match value:
 33                    case L3.Immediate() | L3.Allocate():
 34                        simple_binding_indices.add(i)
 35                    case _:
 36                        pass
 37
 38            # Separate simple and complex bindings
 39            simple_bindings: list[tuple[str, L2.Term]] = []
 40            complex_bindings: list[tuple[str, L3.Term]] = []
 41            complex_binding_names: list[str] = []
 42
 43            for i, (name, value) in enumerate(bindings):
 44                if i in simple_binding_indices:
 45                    transformed_value = eliminate_letrec_term(value, new_context)
 46                    simple_bindings.append((name, transformed_value))
 47                else:
 48                    complex_bindings.append((name, value))
 49                    complex_binding_names.append(name)
 50
 51            # Create stores for complex bindings
 52            stores: list[L2.Term] = []
 53            for name, value in complex_bindings:
 54                transformed_value = eliminate_letrec_term(value, new_context)
 55                stores.append(
 56                    L2.Store(
 57                        base=L2.Reference(name=name),
 58                        index=0,
 59                        value=transformed_value,
 60                    )
 61                )
 62
 63            # Transform the body
 64            transformed_body = eliminate_letrec_term(body, new_context)
 65
 66            # Build the result
 67            all_bindings = simple_bindings + [(name, L2.Allocate(count=1)) for name in complex_binding_names]
 68
 69            if stores:
 70                return L2.Let(
 71                    bindings=all_bindings,
 72                    body=L2.Begin(
 73                        effects=stores,
 74                        value=transformed_body,
 75                    ),
 76                )
 77            else:
 78                return L2.Let(
 79                    bindings=all_bindings,
 80                    body=transformed_body,
 81                )
 82
 83        case L3.Reference(name=name):
 84            # if name is a recursive variable -> (Load (Reference name)))
 85            # else (Reference name)
 86            if name in context:
 87                return L2.Load(base=L2.Reference(name=name), index=0)
 88            else:
 89                return L2.Reference(name=name)
 90
 91        case L3.Abstract(parameters=parameters, body=body):
 92            return L2.Abstract(parameters=parameters, body=eliminate_letrec_term(body, context))
 93
 94        case L3.Apply(target=target, arguments=arguments):
 95            return L2.Apply(
 96                target=eliminate_letrec_term(target, context),
 97                arguments=[eliminate_letrec_term(argument, context) for argument in arguments],
 98            )
 99
100        case L3.Immediate(value=value):
101            return L2.Immediate(value=value)
102
103        case L3.Primitive(operator=operator, left=left, right=right):
104            return L2.Primitive(
105                operator=operator,
106                left=eliminate_letrec_term(left, context),
107                right=eliminate_letrec_term(right, context),
108            )
109
110        case L3.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
111            return L2.Branch(
112                operator=operator,
113                left=eliminate_letrec_term(left, context),
114                right=eliminate_letrec_term(right, context),
115                consequent=eliminate_letrec_term(consequent, context),
116                otherwise=eliminate_letrec_term(otherwise, context),
117            )
118
119        case L3.Allocate(count=count):
120            return L2.Allocate(count=count)
121
122        case L3.Load(base=base, index=index):
123            return L2.Load(
124                base=eliminate_letrec_term(base, context),
125                index=index,
126            )
127
128        case L3.Store(base=base, index=index, value=value):
129            return L2.Store(
130                base=eliminate_letrec_term(base, context),
131                index=index,
132                value=eliminate_letrec_term(value, context),
133            )
134
135        case L3.Begin(effects=effects, value=value):  # pragma: no branch
136            return L2.Begin(
137                effects=[eliminate_letrec_term(effect, context) for effect in effects],
138                value=eliminate_letrec_term(value, context),
139            )
140
141
def eliminate_letrec_program(
    program: L3.Program,
) -> L2.Program:
    """Translate a whole L3 program into L2, removing every ``letrec``.

    The parameter list is carried over unchanged. The body is translated with
    an empty context because no name has been boxed at the top level.
    """
    if isinstance(program, L3.Program):  # pragma: no branch
        translated_body = eliminate_letrec_term(program.body, {})
        return L2.Program(parameters=program.parameters, body=translated_body)
type Context = Mapping[Identifier, None]
def eliminate_letrec_term(term: Term, context: Context) -> Term:
 12def eliminate_letrec_term(
 13    term: L3.Term,
 14    context: Context,
 15) -> L2.Term:
 16    match term:
 17        case L3.Let(bindings=bindings, body=body):
 18            return L2.Let(
 19                bindings=[(name, eliminate_letrec_term(value, context)) for name, value in bindings],
 20                body=eliminate_letrec_term(body, context),
 21            )
 22
 23        case L3.LetRec(bindings=bindings, body=body):
 24            # Mark all binding names as recursive in the context
 25            binding_names = [name for name, _ in bindings]
 26            new_context: Context = {**context, **dict.fromkeys(binding_names)}  # type: ignore
 27
 28            # Check which bindings need heap allocation based on their values
 29            # Simple values (Immediate, Allocate) can be stored directly
 30            # Complex values (everything else) need Allocate + Store
 31            simple_binding_indices: set[int] = set()
 32            for i, (_, value) in enumerate(bindings):
 33                match value:
 34                    case L3.Immediate() | L3.Allocate():
 35                        simple_binding_indices.add(i)
 36                    case _:
 37                        pass
 38
 39            # Separate simple and complex bindings
 40            simple_bindings: list[tuple[str, L2.Term]] = []
 41            complex_bindings: list[tuple[str, L3.Term]] = []
 42            complex_binding_names: list[str] = []
 43
 44            for i, (name, value) in enumerate(bindings):
 45                if i in simple_binding_indices:
 46                    transformed_value = eliminate_letrec_term(value, new_context)
 47                    simple_bindings.append((name, transformed_value))
 48                else:
 49                    complex_bindings.append((name, value))
 50                    complex_binding_names.append(name)
 51
 52            # Create stores for complex bindings
 53            stores: list[L2.Term] = []
 54            for name, value in complex_bindings:
 55                transformed_value = eliminate_letrec_term(value, new_context)
 56                stores.append(
 57                    L2.Store(
 58                        base=L2.Reference(name=name),
 59                        index=0,
 60                        value=transformed_value,
 61                    )
 62                )
 63
 64            # Transform the body
 65            transformed_body = eliminate_letrec_term(body, new_context)
 66
 67            # Build the result
 68            all_bindings = simple_bindings + [(name, L2.Allocate(count=1)) for name in complex_binding_names]
 69
 70            if stores:
 71                return L2.Let(
 72                    bindings=all_bindings,
 73                    body=L2.Begin(
 74                        effects=stores,
 75                        value=transformed_body,
 76                    ),
 77                )
 78            else:
 79                return L2.Let(
 80                    bindings=all_bindings,
 81                    body=transformed_body,
 82                )
 83
 84        case L3.Reference(name=name):
 85            # if name is a recursive variable -> (Load (Reference name)))
 86            # else (Reference name)
 87            if name in context:
 88                return L2.Load(base=L2.Reference(name=name), index=0)
 89            else:
 90                return L2.Reference(name=name)
 91
 92        case L3.Abstract(parameters=parameters, body=body):
 93            return L2.Abstract(parameters=parameters, body=eliminate_letrec_term(body, context))
 94
 95        case L3.Apply(target=target, arguments=arguments):
 96            return L2.Apply(
 97                target=eliminate_letrec_term(target, context),
 98                arguments=[eliminate_letrec_term(argument, context) for argument in arguments],
 99            )
100
101        case L3.Immediate(value=value):
102            return L2.Immediate(value=value)
103
104        case L3.Primitive(operator=operator, left=left, right=right):
105            return L2.Primitive(
106                operator=operator,
107                left=eliminate_letrec_term(left, context),
108                right=eliminate_letrec_term(right, context),
109            )
110
111        case L3.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
112            return L2.Branch(
113                operator=operator,
114                left=eliminate_letrec_term(left, context),
115                right=eliminate_letrec_term(right, context),
116                consequent=eliminate_letrec_term(consequent, context),
117                otherwise=eliminate_letrec_term(otherwise, context),
118            )
119
120        case L3.Allocate(count=count):
121            return L2.Allocate(count=count)
122
123        case L3.Load(base=base, index=index):
124            return L2.Load(
125                base=eliminate_letrec_term(base, context),
126                index=index,
127            )
128
129        case L3.Store(base=base, index=index, value=value):
130            return L2.Store(
131                base=eliminate_letrec_term(base, context),
132                index=index,
133                value=eliminate_letrec_term(value, context),
134            )
135
136        case L3.Begin(effects=effects, value=value):  # pragma: no branch
137            return L2.Begin(
138                effects=[eliminate_letrec_term(effect, context) for effect in effects],
139                value=eliminate_letrec_term(value, context),
140            )
def eliminate_letrec_program(
    program: L3.Program,
) -> L2.Program:
    """Translate a whole L3 program into L2, removing every ``letrec``.

    The parameter list is carried over unchanged. The body is translated with
    an empty context because no name has been boxed at the top level.
    """
    if isinstance(program, L3.Program):  # pragma: no branch
        translated_body = eliminate_letrec_term(program.body, {})
        return L2.Program(parameters=program.parameters, body=translated_body)