L0.to_python

  1import ast
  2from ast import stmt
  3from functools import partial
  4
  5from util.encode import encode
  6
  7from .syntax import (
  8    Address,
  9    Allocate,
 10    Branch,
 11    Call,
 12    Copy,
 13    Halt,
 14    Immediate,
 15    Load,
 16    Primitive,
 17    Procedure,
 18    Program,
 19    Statement,
 20    Store,
 21)
 22
 23
 24def load(name: str) -> ast.Name:
 25    return ast.Name(id=encode(name), ctx=ast.Load())
 26
 27
 28def store(name: str) -> ast.Name:
 29    return ast.Name(id=encode(name), ctx=ast.Store())
 30
 31
 32def to_ast_statement(
 33    term: Statement,
 34) -> list[ast.stmt]:
 35    _statement: partial[list[stmt]] = partial(to_ast_statement)
 36
 37    match term:
 38        case Copy(destination=destination, source=source, then=then):
 39            return [
 40                ast.Assign(targets=[store(destination)], value=load(source)),
 41                *_statement(then),
 42            ]
 43
 44        case Immediate(destination=destination, value=value, then=then):
 45            return [
 46                ast.Assign(targets=[store(destination)], value=ast.Constant(value=value)),
 47                *_statement(then),
 48            ]
 49
 50        case Primitive(destination=destination, operator=operator, left=left, right=right, then=then):
 51            match operator:
 52                case "+":
 53                    op = ast.Add()
 54
 55                case "-":
 56                    op = ast.Sub()
 57
 58                case "*":  # pragma: no branch
 59                    op = ast.Mult()
 60
 61            return [
 62                ast.Assign(
 63                    targets=[store(destination)],
 64                    value=ast.BinOp(
 65                        left=load(left),
 66                        op=op,
 67                        right=load(right),
 68                    ),
 69                ),
 70                *_statement(then),
 71            ]
 72
 73        case Branch(operator=operator, left=left, right=right, then=then, otherwise=otherwise):
 74            match operator:
 75                case "<":
 76                    op = ast.Lt()
 77
 78                case "==":  # pragma: no branch
 79                    op = ast.Eq()
 80
 81            return [
 82                ast.If(
 83                    test=ast.Compare(
 84                        left=load(left),
 85                        ops=[op],
 86                        comparators=[load(right)],
 87                    ),
 88                    body=_statement(then),
 89                    orelse=_statement(otherwise),
 90                ),
 91            ]
 92
 93        case Allocate(destination=destination, count=count, then=then):
 94            return [
 95                ast.Assign(
 96                    targets=[store(destination)],
 97                    value=ast.List(
 98                        elts=[ast.Constant(None) for _ in range(count)],
 99                        ctx=ast.Load(),
100                    ),
101                ),
102                *_statement(then),
103            ]
104
105        case Load(destination=destination, base=base, index=index, then=then):
106            return [
107                ast.Assign(
108                    targets=[store(destination)],
109                    value=ast.Subscript(
110                        value=load(base),
111                        slice=ast.Constant(index),
112                        ctx=ast.Load(),
113                    ),
114                ),
115                *_statement(then),
116            ]
117
118        case Store(base=base, index=index, value=value, then=then):
119            return [
120                ast.Assign(
121                    targets=[
122                        ast.Subscript(
123                            value=store(base),
124                            slice=ast.Constant(index),
125                            ctx=ast.Store(),
126                        )
127                    ],
128                    value=load(value),
129                ),
130                *_statement(then),
131            ]
132
133        case Address(destination=destination, name=name, then=then):
134            return [
135                ast.Assign(targets=[store(destination)], value=load(name)),
136                *_statement(then),
137            ]
138
139        case Call(target=target, arguments=arguments):
140            return [
141                ast.Return(
142                    value=ast.Call(
143                        func=load(target),
144                        args=[load(argument) for argument in arguments],
145                    )
146                )
147            ]
148
149        case Halt(value=value):  # pragma: no branch
150            return [
151                ast.Return(value=load(value)),
152            ]
153
154
155def to_ast_procedure(procedure: Procedure) -> ast.stmt:
156    _statement: partial[list[stmt]] = partial(to_ast_statement)
157
158    match procedure:
159        case Procedure(name=name, parameters=parameters, body=body):  # pragma: no branch
160            return ast.FunctionDef(
161                name=name,
162                args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]),
163                body=_statement(body),
164            )
165
166
167def to_ast_program(
168    program: Program,
169) -> str:
170    _procedure = partial(to_ast_procedure)
171    _statement = partial(to_ast_statement)
172
173    match program:
174        case Program(procedures=procedures):  # pragma: no branch
175            l0 = next(procedure for procedure in procedures if procedure.name == "l0")
176
177            module = ast.Module(
178                body=[
179                    *[_procedure(procedure) for procedure in procedures],
180                    ast.If(
181                        test=ast.Compare(
182                            left=ast.Name(id="__name__", ctx=ast.Load()),
183                            ops=[ast.Eq()],
184                            comparators=[ast.Constant(value="__main__")],
185                        ),
186                        body=[
187                            ast.Import(names=[ast.alias(name="sys", asname=None)]),
188                            ast.Expr(
189                                value=ast.Call(
190                                    func=ast.Name(id="print", ctx=ast.Load()),
191                                    args=[
192                                        ast.Call(
193                                            func=ast.Name(id="l0", ctx=ast.Load()),
194                                            args=[
195                                                ast.Call(
196                                                    func=ast.Name(id="int", ctx=ast.Load()),
197                                                    args=[
198                                                        ast.Subscript(
199                                                            value=ast.Attribute(
200                                                                value=ast.Name(id="sys", ctx=ast.Load()),
201                                                                attr="argv",
202                                                                ctx=ast.Load(),
203                                                            ),
204                                                            slice=ast.Constant(value=i + 1),
205                                                        )
206                                                    ],
207                                                )
208                                                for i, _ in enumerate(l0.parameters)
209                                            ],
210                                        )
211                                    ],
212                                )
213                            ),
214                        ],
215                    ),
216                ]
217            )
218
219            ast.fix_missing_locations(module)
220
221            return ast.unparse(module)
def load(name: str) -> ast.Name:
    """Return a ``Name`` node that reads the identifier ``encode(name)``."""
    identifier = encode(name)
    return ast.Name(ctx=ast.Load(), id=identifier)
def store(name: str) -> ast.Name:
    """Return a ``Name`` node that assigns to the identifier ``encode(name)``."""
    identifier = encode(name)
    return ast.Name(ctx=ast.Store(), id=identifier)
def to_ast_statement(term: Statement) -> list[ast.stmt]:
 33def to_ast_statement(
 34    term: Statement,
 35) -> list[ast.stmt]:
 36    _statement: partial[list[stmt]] = partial(to_ast_statement)
 37
 38    match term:
 39        case Copy(destination=destination, source=source, then=then):
 40            return [
 41                ast.Assign(targets=[store(destination)], value=load(source)),
 42                *_statement(then),
 43            ]
 44
 45        case Immediate(destination=destination, value=value, then=then):
 46            return [
 47                ast.Assign(targets=[store(destination)], value=ast.Constant(value=value)),
 48                *_statement(then),
 49            ]
 50
 51        case Primitive(destination=destination, operator=operator, left=left, right=right, then=then):
 52            match operator:
 53                case "+":
 54                    op = ast.Add()
 55
 56                case "-":
 57                    op = ast.Sub()
 58
 59                case "*":  # pragma: no branch
 60                    op = ast.Mult()
 61
 62            return [
 63                ast.Assign(
 64                    targets=[store(destination)],
 65                    value=ast.BinOp(
 66                        left=load(left),
 67                        op=op,
 68                        right=load(right),
 69                    ),
 70                ),
 71                *_statement(then),
 72            ]
 73
 74        case Branch(operator=operator, left=left, right=right, then=then, otherwise=otherwise):
 75            match operator:
 76                case "<":
 77                    op = ast.Lt()
 78
 79                case "==":  # pragma: no branch
 80                    op = ast.Eq()
 81
 82            return [
 83                ast.If(
 84                    test=ast.Compare(
 85                        left=load(left),
 86                        ops=[op],
 87                        comparators=[load(right)],
 88                    ),
 89                    body=_statement(then),
 90                    orelse=_statement(otherwise),
 91                ),
 92            ]
 93
 94        case Allocate(destination=destination, count=count, then=then):
 95            return [
 96                ast.Assign(
 97                    targets=[store(destination)],
 98                    value=ast.List(
 99                        elts=[ast.Constant(None) for _ in range(count)],
100                        ctx=ast.Load(),
101                    ),
102                ),
103                *_statement(then),
104            ]
105
106        case Load(destination=destination, base=base, index=index, then=then):
107            return [
108                ast.Assign(
109                    targets=[store(destination)],
110                    value=ast.Subscript(
111                        value=load(base),
112                        slice=ast.Constant(index),
113                        ctx=ast.Load(),
114                    ),
115                ),
116                *_statement(then),
117            ]
118
119        case Store(base=base, index=index, value=value, then=then):
120            return [
121                ast.Assign(
122                    targets=[
123                        ast.Subscript(
124                            value=store(base),
125                            slice=ast.Constant(index),
126                            ctx=ast.Store(),
127                        )
128                    ],
129                    value=load(value),
130                ),
131                *_statement(then),
132            ]
133
134        case Address(destination=destination, name=name, then=then):
135            return [
136                ast.Assign(targets=[store(destination)], value=load(name)),
137                *_statement(then),
138            ]
139
140        case Call(target=target, arguments=arguments):
141            return [
142                ast.Return(
143                    value=ast.Call(
144                        func=load(target),
145                        args=[load(argument) for argument in arguments],
146                    )
147                )
148            ]
149
150        case Halt(value=value):  # pragma: no branch
151            return [
152                ast.Return(value=load(value)),
153            ]
def to_ast_procedure(procedure: L0.syntax.Procedure) -> ast.stmt:
156def to_ast_procedure(procedure: Procedure) -> ast.stmt:
157    _statement: partial[list[stmt]] = partial(to_ast_statement)
158
159    match procedure:
160        case Procedure(name=name, parameters=parameters, body=body):  # pragma: no branch
161            return ast.FunctionDef(
162                name=name,
163                args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]),
164                body=_statement(body),
165            )
def to_ast_program(program: L0.syntax.Program) -> str:
168def to_ast_program(
169    program: Program,
170) -> str:
171    _procedure = partial(to_ast_procedure)
172    _statement = partial(to_ast_statement)
173
174    match program:
175        case Program(procedures=procedures):  # pragma: no branch
176            l0 = next(procedure for procedure in procedures if procedure.name == "l0")
177
178            module = ast.Module(
179                body=[
180                    *[_procedure(procedure) for procedure in procedures],
181                    ast.If(
182                        test=ast.Compare(
183                            left=ast.Name(id="__name__", ctx=ast.Load()),
184                            ops=[ast.Eq()],
185                            comparators=[ast.Constant(value="__main__")],
186                        ),
187                        body=[
188                            ast.Import(names=[ast.alias(name="sys", asname=None)]),
189                            ast.Expr(
190                                value=ast.Call(
191                                    func=ast.Name(id="print", ctx=ast.Load()),
192                                    args=[
193                                        ast.Call(
194                                            func=ast.Name(id="l0", ctx=ast.Load()),
195                                            args=[
196                                                ast.Call(
197                                                    func=ast.Name(id="int", ctx=ast.Load()),
198                                                    args=[
199                                                        ast.Subscript(
200                                                            value=ast.Attribute(
201                                                                value=ast.Name(id="sys", ctx=ast.Load()),
202                                                                attr="argv",
203                                                                ctx=ast.Load(),
204                                                            ),
205                                                            slice=ast.Constant(value=i + 1),
206                                                        )
207                                                    ],
208                                                )
209                                                for i, _ in enumerate(l0.parameters)
210                                            ],
211                                        )
212                                    ],
213                                )
214                            ),
215                        ],
216                    ),
217                ]
218            )
219
220            ast.fix_missing_locations(module)
221
222            return ast.unparse(module)