L2.to_python

  1import ast
  2from functools import partial
  3
  4from util.encode import encode
  5
  6from .syntax import (
  7    Abstract,
  8    Allocate,
  9    Apply,
 10    Begin,
 11    Branch,
 12    Immediate,
 13    Let,
 14    Load,
 15    Primitive,
 16    Program,
 17    Reference,
 18    Store,
 19    Term,
 20)
 21
 22
 23def to_ast_term(
 24    term: Term,
 25) -> ast.expr:
 26    _term = partial(to_ast_term)
 27
 28    match term:
 29        case Let(bindings=bindings, body=body):
 30            return ast.Subscript(
 31                value=ast.Tuple(
 32                    elts=[
 33                        *[
 34                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value))
 35                            for name, value in bindings
 36                        ],
 37                        _term(body),
 38                    ],
 39                    ctx=ast.Load(),
 40                ),
 41                slice=ast.Constant(-1),
 42                ctx=ast.Load(),
 43            )
 44
 45        case Reference(name=name):
 46            return ast.Name(id=encode(name), ctx=ast.Load())
 47
 48        case Abstract(parameters=parameters, body=body):
 49            return ast.Lambda(
 50                args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]),
 51                body=_term(body),
 52            )
 53
 54        case Apply(target=target, arguments=arguments):
 55            return ast.Call(
 56                func=_term(target),
 57                args=[_term(argument) for argument in arguments],
 58            )
 59
 60        case Immediate(value=value):
 61            return ast.Constant(value=value)
 62
 63        case Primitive(operator=operator, left=left, right=right):
 64            match operator:
 65                case "+":
 66                    op = ast.Add()
 67
 68                case "-":
 69                    op = ast.Sub()
 70
 71                case "*":  # pragma: no branch
 72                    op = ast.Mult()
 73
 74            return ast.BinOp(left=_term(left), op=op, right=_term(right))
 75
 76        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
 77            match operator:
 78                case "<":
 79                    op = ast.Lt()
 80
 81                case "==":  # pragma: no branch
 82                    op = ast.Eq()
 83
 84            return ast.IfExp(
 85                test=ast.Compare(left=_term(left), ops=[op], comparators=[_term(right)]),
 86                body=_term(consequent),
 87                orelse=_term(otherwise),
 88            )
 89
 90        case Allocate(count=count):
 91            return ast.List(
 92                elts=[ast.Constant(None) for _ in range(count)],
 93                ctx=ast.Load(),
 94            )
 95
 96        case Load(base=base, index=index):
 97            return ast.Call(
 98                func=ast.Attribute(value=_term(base), attr="__getitem__", ctx=ast.Load()),
 99                args=[ast.Constant(value=index)],
100            )
101
102        case Store(base=base, index=index, value=value):
103            return ast.Subscript(
104                value=ast.Tuple(
105                    elts=[
106                        ast.Call(
107                            func=ast.Attribute(value=_term(base), attr="__setitem__", ctx=ast.Load()),
108                            args=[ast.Constant(value=index), _term(value)],
109                        ),
110                        ast.Constant(value=0),
111                    ],
112                    ctx=ast.Load(),
113                ),
114                slice=ast.Constant(-1),
115                ctx=ast.Load(),
116            )
117
118        case Begin(effects=effects, value=value):  # pragma: no branch
119            return ast.Subscript(
120                value=ast.Tuple(
121                    elts=[
122                        *[_term(effect) for effect in effects],
123                        _term(value),
124                    ],
125                    ctx=ast.Load(),
126                ),
127                slice=ast.Constant(-1),
128                ctx=ast.Load(),
129            )
130
131
132def to_ast_program(
133    program: Program,
134) -> str:
135    match program:
136        case Program(parameters=parameters, body=body):  # pragma: no branch
137            module = ast.Module(
138                body=[
139                    ast.FunctionDef(
140                        name="l2",
141                        args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]),
142                        body=[
143                            ast.Return(value=to_ast_term(body)),
144                        ],
145                    ),
146                    ast.If(
147                        test=ast.Compare(
148                            left=ast.Name(id="__name__", ctx=ast.Load()),
149                            ops=[ast.Eq()],
150                            comparators=[ast.Constant(value="__main__")],
151                        ),
152                        body=[
153                            ast.Import(names=[ast.alias(name="sys", asname=None)]),
154                            ast.Expr(
155                                value=ast.Call(
156                                    func=ast.Name(id="print", ctx=ast.Load()),
157                                    args=[
158                                        ast.Call(
159                                            func=ast.Name(id="l2", ctx=ast.Load()),
160                                            args=[
161                                                ast.Call(
162                                                    func=ast.Name(id="int", ctx=ast.Load()),
163                                                    args=[
164                                                        ast.Subscript(
165                                                            value=ast.Attribute(
166                                                                value=ast.Name(id="sys", ctx=ast.Load()),
167                                                                attr="argv",
168                                                                ctx=ast.Load(),
169                                                            ),
170                                                            slice=ast.Constant(value=i + 1),
171                                                        )
172                                                    ],
173                                                )
174                                                for i, _ in enumerate(parameters)
175                                            ],
176                                        )
177                                    ],
178                                )
179                            ),
180                        ],
181                    ),
182                ]
183            )
184
185            ast.fix_missing_locations(module)
186
187            return ast.unparse(module)
def to_ast_term(term: Term) -> ast.expr:
 24def to_ast_term(
 25    term: Term,
 26) -> ast.expr:
 27    _term = partial(to_ast_term)
 28
 29    match term:
 30        case Let(bindings=bindings, body=body):
 31            return ast.Subscript(
 32                value=ast.Tuple(
 33                    elts=[
 34                        *[
 35                            ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value))
 36                            for name, value in bindings
 37                        ],
 38                        _term(body),
 39                    ],
 40                    ctx=ast.Load(),
 41                ),
 42                slice=ast.Constant(-1),
 43                ctx=ast.Load(),
 44            )
 45
 46        case Reference(name=name):
 47            return ast.Name(id=encode(name), ctx=ast.Load())
 48
 49        case Abstract(parameters=parameters, body=body):
 50            return ast.Lambda(
 51                args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]),
 52                body=_term(body),
 53            )
 54
 55        case Apply(target=target, arguments=arguments):
 56            return ast.Call(
 57                func=_term(target),
 58                args=[_term(argument) for argument in arguments],
 59            )
 60
 61        case Immediate(value=value):
 62            return ast.Constant(value=value)
 63
 64        case Primitive(operator=operator, left=left, right=right):
 65            match operator:
 66                case "+":
 67                    op = ast.Add()
 68
 69                case "-":
 70                    op = ast.Sub()
 71
 72                case "*":  # pragma: no branch
 73                    op = ast.Mult()
 74
 75            return ast.BinOp(left=_term(left), op=op, right=_term(right))
 76
 77        case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
 78            match operator:
 79                case "<":
 80                    op = ast.Lt()
 81
 82                case "==":  # pragma: no branch
 83                    op = ast.Eq()
 84
 85            return ast.IfExp(
 86                test=ast.Compare(left=_term(left), ops=[op], comparators=[_term(right)]),
 87                body=_term(consequent),
 88                orelse=_term(otherwise),
 89            )
 90
 91        case Allocate(count=count):
 92            return ast.List(
 93                elts=[ast.Constant(None) for _ in range(count)],
 94                ctx=ast.Load(),
 95            )
 96
 97        case Load(base=base, index=index):
 98            return ast.Call(
 99                func=ast.Attribute(value=_term(base), attr="__getitem__", ctx=ast.Load()),
100                args=[ast.Constant(value=index)],
101            )
102
103        case Store(base=base, index=index, value=value):
104            return ast.Subscript(
105                value=ast.Tuple(
106                    elts=[
107                        ast.Call(
108                            func=ast.Attribute(value=_term(base), attr="__setitem__", ctx=ast.Load()),
109                            args=[ast.Constant(value=index), _term(value)],
110                        ),
111                        ast.Constant(value=0),
112                    ],
113                    ctx=ast.Load(),
114                ),
115                slice=ast.Constant(-1),
116                ctx=ast.Load(),
117            )
118
119        case Begin(effects=effects, value=value):  # pragma: no branch
120            return ast.Subscript(
121                value=ast.Tuple(
122                    elts=[
123                        *[_term(effect) for effect in effects],
124                        _term(value),
125                    ],
126                    ctx=ast.Load(),
127                ),
128                slice=ast.Constant(-1),
129                ctx=ast.Load(),
130            )
def to_ast_program(program: L2.syntax.Program) -> str:
133def to_ast_program(
134    program: Program,
135) -> str:
136    match program:
137        case Program(parameters=parameters, body=body):  # pragma: no branch
138            module = ast.Module(
139                body=[
140                    ast.FunctionDef(
141                        name="l2",
142                        args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]),
143                        body=[
144                            ast.Return(value=to_ast_term(body)),
145                        ],
146                    ),
147                    ast.If(
148                        test=ast.Compare(
149                            left=ast.Name(id="__name__", ctx=ast.Load()),
150                            ops=[ast.Eq()],
151                            comparators=[ast.Constant(value="__main__")],
152                        ),
153                        body=[
154                            ast.Import(names=[ast.alias(name="sys", asname=None)]),
155                            ast.Expr(
156                                value=ast.Call(
157                                    func=ast.Name(id="print", ctx=ast.Load()),
158                                    args=[
159                                        ast.Call(
160                                            func=ast.Name(id="l2", ctx=ast.Load()),
161                                            args=[
162                                                ast.Call(
163                                                    func=ast.Name(id="int", ctx=ast.Load()),
164                                                    args=[
165                                                        ast.Subscript(
166                                                            value=ast.Attribute(
167                                                                value=ast.Name(id="sys", ctx=ast.Load()),
168                                                                attr="argv",
169                                                                ctx=ast.Load(),
170                                                            ),
171                                                            slice=ast.Constant(value=i + 1),
172                                                        )
173                                                    ],
174                                                )
175                                                for i, _ in enumerate(parameters)
176                                            ],
177                                        )
178                                    ],
179                                )
180                            ),
181                        ],
182                    ),
183                ]
184            )
185
186            ast.fix_missing_locations(module)
187
188            return ast.unparse(module)