L2.cps_convert

  1from collections.abc import Callable, Sequence
  2from functools import partial
  3
  4from L1 import syntax as L1
  5
  6from L2 import syntax as L2
  7
  8
  9def cps_convert_term(
 10    term: L2.Term,
 11    m: Callable[[L1.Identifier], L1.Statement],
 12    fresh: Callable[[str], str],
 13) -> L1.Statement:
 14    _term = partial(cps_convert_term, fresh=fresh)
 15    _terms = partial(cps_convert_terms, fresh=fresh)
 16
 17    match term:
 18        case L2.Let(bindings=bindings, body=body):
 19            result = _term(body, m)
 20
 21            for name, value in reversed(bindings):
 22                result = _term(value, lambda value: L1.Copy(destination=name, source=value, then=result))
 23
 24            return result
 25
 26        case L2.Reference(name=name):
 27            return m(name)
 28
 29        case L2.Abstract(parameters=parameters, body=body):
 30            tmp = fresh("t")
 31            k = fresh("k")
 32            return L1.Abstract(
 33                destination=tmp,
 34                parameters=[*parameters, k],
 35                body=_term(body, lambda body: L1.Apply(target=k, arguments=[body])),
 36                then=m(tmp),
 37            )
 38
 39        case L2.Apply(target=target, arguments=arguments):
 40            k = fresh("k")
 41            tmp = fresh("t")
 42            return L1.Abstract(
 43                destination=k,
 44                parameters=[tmp],
 45                body=m(tmp),
 46                then=_term(
 47                    target,
 48                    lambda target: _terms(
 49                        arguments,
 50                        lambda arguments: L1.Apply(target=target, arguments=[*arguments, k]),
 51                    ),
 52                ),
 53            )
 54
 55        case L2.Immediate(value=value):
 56            tmp = fresh("t")
 57            return L1.Immediate(destination=tmp, value=value, then=m(tmp))
 58
 59        case L2.Primitive(operator=operator, left=left, right=right):
 60            tmp = fresh("t")
 61            return _term(
 62                left,
 63                lambda left: _term(
 64                    right,
 65                    lambda right: L1.Primitive(destination=tmp, operator=operator, left=left, right=right, then=m(tmp)),
 66                ),
 67            )
 68
 69        case L2.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
 70            j = fresh("j")
 71            tmp = fresh("t")
 72            return L1.Abstract(
 73                destination=j,
 74                parameters=[tmp],
 75                body=m(tmp),
 76                then=_term(
 77                    left,
 78                    lambda left: _term(
 79                        right,
 80                        lambda right: L1.Branch(
 81                            operator=operator,
 82                            left=left,
 83                            right=right,
 84                            then=_term(consequent, lambda consequent: L1.Apply(target=j, arguments=[consequent])),
 85                            otherwise=_term(otherwise, lambda otherwise: L1.Apply(target=j, arguments=[otherwise])),
 86                        ),
 87                    ),
 88                ),
 89            )
 90
 91        case L2.Allocate(count=count):
 92            tmp = fresh("t")
 93            return L1.Allocate(destination=tmp, count=count, then=m(tmp))
 94
 95        case L2.Load(base=base, index=index):
 96            tmp = fresh("t")
 97            return _term(
 98                base,
 99                lambda base: L1.Load(destination=tmp, base=base, index=index, then=m(tmp)),
100            )
101
102        # Should double check this
103        case L2.Store(base=base, index=index, value=value):
104            tmp = fresh("t")
105            return _term(
106                base,
107                lambda base: _term(
108                    value,
109                    lambda value: L1.Store(
110                        base=base, index=index, value=value, then=L1.Immediate(destination=tmp, value=0, then=m(tmp))
111                    ),
112                ),
113            )
114
115        case L2.Begin(effects=effects, value=value):  # pragma: no branch
116            return _terms(
117                effects,
118                lambda effects: _term(
119                    value,
120                    lambda value: m(value),
121                ),
122            )
123
124
125def cps_convert_terms(
126    terms: Sequence[L2.Term],
127    k: Callable[[Sequence[L1.Identifier]], L1.Statement],
128    fresh: Callable[[str], str],
129) -> L1.Statement:
130    _term = partial(cps_convert_term, fresh=fresh)
131    _terms = partial(cps_convert_terms, fresh=fresh)
132
133    match terms:
134        case []:
135            return k([])
136
137        case [first, *rest]:
138            return _term(first, lambda first: _terms(rest, lambda rest: k([first, *rest])))
139
140        case _:  # pragma: no cover
141            raise ValueError(terms)
142
143
144def cps_convert_program(
145    program: L2.Program,
146    fresh: Callable[[str], str],
147) -> L1.Program:
148    _term = partial(cps_convert_term, fresh=fresh)
149
150    match program:
151        case L2.Program(parameters=parameters, body=body):  # pragma: no branch
152            return L1.Program(
153                parameters=parameters,
154                body=_term(body, lambda value: L1.Halt(value=value)),
155            )
def cps_convert_term( term: Term, m: Callable[[Identifier], Statement], fresh: Callable[[str], str]) -> Statement:
 10def cps_convert_term(
 11    term: L2.Term,
 12    m: Callable[[L1.Identifier], L1.Statement],
 13    fresh: Callable[[str], str],
 14) -> L1.Statement:
 15    _term = partial(cps_convert_term, fresh=fresh)
 16    _terms = partial(cps_convert_terms, fresh=fresh)
 17
 18    match term:
 19        case L2.Let(bindings=bindings, body=body):
 20            result = _term(body, m)
 21
 22            for name, value in reversed(bindings):
 23                result = _term(value, lambda value: L1.Copy(destination=name, source=value, then=result))
 24
 25            return result
 26
 27        case L2.Reference(name=name):
 28            return m(name)
 29
 30        case L2.Abstract(parameters=parameters, body=body):
 31            tmp = fresh("t")
 32            k = fresh("k")
 33            return L1.Abstract(
 34                destination=tmp,
 35                parameters=[*parameters, k],
 36                body=_term(body, lambda body: L1.Apply(target=k, arguments=[body])),
 37                then=m(tmp),
 38            )
 39
 40        case L2.Apply(target=target, arguments=arguments):
 41            k = fresh("k")
 42            tmp = fresh("t")
 43            return L1.Abstract(
 44                destination=k,
 45                parameters=[tmp],
 46                body=m(tmp),
 47                then=_term(
 48                    target,
 49                    lambda target: _terms(
 50                        arguments,
 51                        lambda arguments: L1.Apply(target=target, arguments=[*arguments, k]),
 52                    ),
 53                ),
 54            )
 55
 56        case L2.Immediate(value=value):
 57            tmp = fresh("t")
 58            return L1.Immediate(destination=tmp, value=value, then=m(tmp))
 59
 60        case L2.Primitive(operator=operator, left=left, right=right):
 61            tmp = fresh("t")
 62            return _term(
 63                left,
 64                lambda left: _term(
 65                    right,
 66                    lambda right: L1.Primitive(destination=tmp, operator=operator, left=left, right=right, then=m(tmp)),
 67                ),
 68            )
 69
 70        case L2.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
 71            j = fresh("j")
 72            tmp = fresh("t")
 73            return L1.Abstract(
 74                destination=j,
 75                parameters=[tmp],
 76                body=m(tmp),
 77                then=_term(
 78                    left,
 79                    lambda left: _term(
 80                        right,
 81                        lambda right: L1.Branch(
 82                            operator=operator,
 83                            left=left,
 84                            right=right,
 85                            then=_term(consequent, lambda consequent: L1.Apply(target=j, arguments=[consequent])),
 86                            otherwise=_term(otherwise, lambda otherwise: L1.Apply(target=j, arguments=[otherwise])),
 87                        ),
 88                    ),
 89                ),
 90            )
 91
 92        case L2.Allocate(count=count):
 93            tmp = fresh("t")
 94            return L1.Allocate(destination=tmp, count=count, then=m(tmp))
 95
 96        case L2.Load(base=base, index=index):
 97            tmp = fresh("t")
 98            return _term(
 99                base,
100                lambda base: L1.Load(destination=tmp, base=base, index=index, then=m(tmp)),
101            )
102
103        # Should double check this
104        case L2.Store(base=base, index=index, value=value):
105            tmp = fresh("t")
106            return _term(
107                base,
108                lambda base: _term(
109                    value,
110                    lambda value: L1.Store(
111                        base=base, index=index, value=value, then=L1.Immediate(destination=tmp, value=0, then=m(tmp))
112                    ),
113                ),
114            )
115
116        case L2.Begin(effects=effects, value=value):  # pragma: no branch
117            return _terms(
118                effects,
119                lambda effects: _term(
120                    value,
121                    lambda value: m(value),
122                ),
123            )
def cps_convert_terms( terms: Sequence[Term], k: Callable[[Sequence[Identifier]], Statement], fresh: Callable[[str], str]) -> Statement:
126def cps_convert_terms(
127    terms: Sequence[L2.Term],
128    k: Callable[[Sequence[L1.Identifier]], L1.Statement],
129    fresh: Callable[[str], str],
130) -> L1.Statement:
131    _term = partial(cps_convert_term, fresh=fresh)
132    _terms = partial(cps_convert_terms, fresh=fresh)
133
134    match terms:
135        case []:
136            return k([])
137
138        case [first, *rest]:
139            return _term(first, lambda first: _terms(rest, lambda rest: k([first, *rest])))
140
141        case _:  # pragma: no cover
142            raise ValueError(terms)
def cps_convert_program( program: L2.syntax.Program, fresh: Callable[[str], str]) -> L1.syntax.Program:
145def cps_convert_program(
146    program: L2.Program,
147    fresh: Callable[[str], str],
148) -> L1.Program:
149    _term = partial(cps_convert_term, fresh=fresh)
150
151    match program:
152        case L2.Program(parameters=parameters, body=body):  # pragma: no branch
153            return L1.Program(
154                parameters=parameters,
155                body=_term(body, lambda value: L1.Halt(value=value)),
156            )