L3.to_python

  1import ast
  2from functools import partial
  3
  4from util.encode import encode
  5
  6from .syntax import (
  7    Abstract,
  8    Allocate,
  9    Apply,
 10    Begin,
 11    Branch,
 12    Immediate,
 13    Let,
 14    LetRec,
 15    Load,
 16    Primitive,
 17    Program,
 18    Reference,
 19    Store,
 20    Term,
 21)
 22
 23
 24def to_ast_term(
 25    term: Term,
 26) -> ast.expr:
 27    _term = partial(to_ast_term)
 28
 29    match term:
 30        case Let(bindings=bindings, body=body):
 31            return ast.Subscript(
 32                value=ast.Tuple(
 33                    elts=[
 34                        *[
 35                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value))
 36                            for name, value in bindings
 37                        ],
 38                        _term(body),
 39                    ],
 40                    ctx=ast.Load(),
 41                ),
 42                slice=ast.Constant(-1),
 43                ctx=ast.Load(),
 44            )
 45
 46        case LetRec(bindings=bindings, body=body):
 47            return ast.Subscript(
 48                value=ast.Tuple(
 49                    elts=[
 50                        *[
 51                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=ast.Constant(None))
 52                            for name, _value in bindings
 53                        ],
 54                        *[
 55                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value))
 56                            for name, value in bindings
 57                        ],
 58                        _term(body),
 59                    ],
 60                    ctx=ast.Load(),
 61                ),
 62                slice=ast.Constant(-1),
 63                ctx=ast.Load(),
 64            )
 65
 66        case Reference(name=name):
 67            return ast.Name(id=encode(name), ctx=ast.Load())
 68
 69        case Abstract(parameters=parameters, body=body):
 70            return ast.Lambda(
 71                args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]),
 72                body=_term(body),
 73            )
 74
 75        case Apply(target=target, arguments=arguments):
 76            return ast.Call(
 77                func=_term(target),
 78                args=[_term(argument) for argument in arguments],
 79            )
 80
 81        case Immediate(value=value):
 82            return ast.Constant(value=value)
 83
 84        case Primitive(operator=operator, left=left, right=right):
 85            match operator:
 86                case "+":
 87                    op = ast.Add()
 88
 89                case "-":
 90                    op = ast.Sub()
 91
 92                case "*":  # pragma: no branch
 93                    op = ast.Mult()
 94
 95            return ast.BinOp(
 96                left=_term(left),
 97                op=op,
 98                right=_term(right),
 99            )
100
101        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
102            match operator:
103                case "<":
104                    op = ast.Lt()
105
106                case "==":  # pragma: no branch
107                    op = ast.Eq()
108
109            return ast.IfExp(
110                test=ast.Compare(
111                    left=_term(left),
112                    ops=[op],
113                    comparators=[_term(right)],
114                ),
115                body=_term(consequent),
116                orelse=_term(otherwise),
117            )
118
119        case Allocate(count=count):
120            return ast.List(
121                elts=[ast.Constant(None) for _ in range(count)],
122                ctx=ast.Load(),
123            )
124
125        case Load(base=base, index=index):
126            return ast.Call(
127                func=ast.Attribute(value=_term(base), attr="__getitem__", ctx=ast.Load()),
128                args=[ast.Constant(value=index)],
129            )
130
131        case Store(base=base, index=index, value=value):
132            return ast.Subscript(
133                value=ast.Tuple(
134                    elts=[
135                        ast.Call(
136                            func=ast.Attribute(value=_term(base), attr="__setitem__", ctx=ast.Load()),
137                            args=[ast.Constant(value=index), _term(value)],
138                        ),
139                        ast.Constant(value=0),
140                    ],
141                    ctx=ast.Load(),
142                ),
143                slice=ast.Constant(-1),
144                ctx=ast.Load(),
145            )
146
147        case Begin(effects=effects, value=value):  # pragma: no branch
148            return ast.Subscript(
149                value=ast.Tuple(
150                    elts=[
151                        *[_term(effect) for effect in effects],
152                        _term(value),
153                    ],
154                    ctx=ast.Load(),
155                ),
156                slice=ast.Constant(-1),
157                ctx=ast.Load(),
158            )
159
160
161def to_ast_program(
162    program: Program,
163) -> str:
164    match program:
165        case Program(parameters=parameters, body=body):  # pragma: no branch
166            module = ast.Module(
167                body=[
168                    ast.FunctionDef(
169                        name="l3",
170                        args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]),
171                        body=[
172                            ast.Return(value=to_ast_term(body)),
173                        ],
174                    ),
175                    ast.If(
176                        test=ast.Compare(
177                            left=ast.Name(id="__name__", ctx=ast.Load()),
178                            ops=[ast.Eq()],
179                            comparators=[ast.Constant(value="__main__")],
180                        ),
181                        body=[
182                            ast.Import(names=[ast.alias(name="sys", asname=None)]),
183                            ast.Expr(
184                                value=ast.Call(
185                                    func=ast.Name(id="print", ctx=ast.Load()),
186                                    args=[
187                                        ast.Call(
188                                            func=ast.Name(id="l3", ctx=ast.Load()),
189                                            args=[
190                                                ast.Call(
191                                                    func=ast.Name(id="int", ctx=ast.Load()),
192                                                    args=[
193                                                        ast.Subscript(
194                                                            value=ast.Attribute(
195                                                                value=ast.Name(id="sys", ctx=ast.Load()),
196                                                                attr="argv",
197                                                                ctx=ast.Load(),
198                                                            ),
199                                                            slice=ast.Constant(value=i + 1),
200                                                        )
201                                                    ],
202                                                )
203                                                for i, _ in enumerate(parameters)
204                                            ],
205                                        )
206                                    ],
207                                )
208                            ),
209                        ],
210                    ),
211                ]
212            )
213
214            ast.fix_missing_locations(module)
215
216            return ast.unparse(module)
def to_ast_term(term: Term) -> ast.expr:
 25def to_ast_term(
 26    term: Term,
 27) -> ast.expr:
 28    _term = partial(to_ast_term)
 29
 30    match term:
 31        case Let(bindings=bindings, body=body):
 32            return ast.Subscript(
 33                value=ast.Tuple(
 34                    elts=[
 35                        *[
 36                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value))
 37                            for name, value in bindings
 38                        ],
 39                        _term(body),
 40                    ],
 41                    ctx=ast.Load(),
 42                ),
 43                slice=ast.Constant(-1),
 44                ctx=ast.Load(),
 45            )
 46
 47        case LetRec(bindings=bindings, body=body):
 48            return ast.Subscript(
 49                value=ast.Tuple(
 50                    elts=[
 51                        *[
 52                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=ast.Constant(None))
 53                            for name, _value in bindings
 54                        ],
 55                        *[
 56                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value))
 57                            for name, value in bindings
 58                        ],
 59                        _term(body),
 60                    ],
 61                    ctx=ast.Load(),
 62                ),
 63                slice=ast.Constant(-1),
 64                ctx=ast.Load(),
 65            )
 66
 67        case Reference(name=name):
 68            return ast.Name(id=encode(name), ctx=ast.Load())
 69
 70        case Abstract(parameters=parameters, body=body):
 71            return ast.Lambda(
 72                args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]),
 73                body=_term(body),
 74            )
 75
 76        case Apply(target=target, arguments=arguments):
 77            return ast.Call(
 78                func=_term(target),
 79                args=[_term(argument) for argument in arguments],
 80            )
 81
 82        case Immediate(value=value):
 83            return ast.Constant(value=value)
 84
 85        case Primitive(operator=operator, left=left, right=right):
 86            match operator:
 87                case "+":
 88                    op = ast.Add()
 89
 90                case "-":
 91                    op = ast.Sub()
 92
 93                case "*":  # pragma: no branch
 94                    op = ast.Mult()
 95
 96            return ast.BinOp(
 97                left=_term(left),
 98                op=op,
 99                right=_term(right),
100            )
101
102        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
103            match operator:
104                case "<":
105                    op = ast.Lt()
106
107                case "==":  # pragma: no branch
108                    op = ast.Eq()
109
110            return ast.IfExp(
111                test=ast.Compare(
112                    left=_term(left),
113                    ops=[op],
114                    comparators=[_term(right)],
115                ),
116                body=_term(consequent),
117                orelse=_term(otherwise),
118            )
119
120        case Allocate(count=count):
121            return ast.List(
122                elts=[ast.Constant(None) for _ in range(count)],
123                ctx=ast.Load(),
124            )
125
126        case Load(base=base, index=index):
127            return ast.Call(
128                func=ast.Attribute(value=_term(base), attr="__getitem__", ctx=ast.Load()),
129                args=[ast.Constant(value=index)],
130            )
131
132        case Store(base=base, index=index, value=value):
133            return ast.Subscript(
134                value=ast.Tuple(
135                    elts=[
136                        ast.Call(
137                            func=ast.Attribute(value=_term(base), attr="__setitem__", ctx=ast.Load()),
138                            args=[ast.Constant(value=index), _term(value)],
139                        ),
140                        ast.Constant(value=0),
141                    ],
142                    ctx=ast.Load(),
143                ),
144                slice=ast.Constant(-1),
145                ctx=ast.Load(),
146            )
147
148        case Begin(effects=effects, value=value):  # pragma: no branch
149            return ast.Subscript(
150                value=ast.Tuple(
151                    elts=[
152                        *[_term(effect) for effect in effects],
153                        _term(value),
154                    ],
155                    ctx=ast.Load(),
156                ),
157                slice=ast.Constant(-1),
158                ctx=ast.Load(),
159            )
def to_ast_program(program: L3.syntax.Program) -> str:
162def to_ast_program(
163    program: Program,
164) -> str:
165    match program:
166        case Program(parameters=parameters, body=body):  # pragma: no branch
167            module = ast.Module(
168                body=[
169                    ast.FunctionDef(
170                        name="l3",
171                        args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]),
172                        body=[
173                            ast.Return(value=to_ast_term(body)),
174                        ],
175                    ),
176                    ast.If(
177                        test=ast.Compare(
178                            left=ast.Name(id="__name__", ctx=ast.Load()),
179                            ops=[ast.Eq()],
180                            comparators=[ast.Constant(value="__main__")],
181                        ),
182                        body=[
183                            ast.Import(names=[ast.alias(name="sys", asname=None)]),
184                            ast.Expr(
185                                value=ast.Call(
186                                    func=ast.Name(id="print", ctx=ast.Load()),
187                                    args=[
188                                        ast.Call(
189                                            func=ast.Name(id="l3", ctx=ast.Load()),
190                                            args=[
191                                                ast.Call(
192                                                    func=ast.Name(id="int", ctx=ast.Load()),
193                                                    args=[
194                                                        ast.Subscript(
195                                                            value=ast.Attribute(
196                                                                value=ast.Name(id="sys", ctx=ast.Load()),
197                                                                attr="argv",
198                                                                ctx=ast.Load(),
199                                                            ),
200                                                            slice=ast.Constant(value=i + 1),
201                                                        )
202                                                    ],
203                                                )
204                                                for i, _ in enumerate(parameters)
205                                            ],
206                                        )
207                                    ],
208                                )
209                            ),
210                        ],
211                    ),
212                ]
213            )
214
215            ast.fix_missing_locations(module)
216
217            return ast.unparse(module)