L3.to_python
1import ast 2from functools import partial 3 4from util.encode import encode 5 6from .syntax import ( 7 Abstract, 8 Allocate, 9 Apply, 10 Begin, 11 Branch, 12 Immediate, 13 Let, 14 LetRec, 15 Load, 16 Primitive, 17 Program, 18 Reference, 19 Store, 20 Term, 21) 22 23 24def to_ast_term( 25 term: Term, 26) -> ast.expr: 27 _term = partial(to_ast_term) 28 29 match term: 30 case Let(bindings=bindings, body=body): 31 return ast.Subscript( 32 value=ast.Tuple( 33 elts=[ 34 *[ 35 ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value)) 36 for name, value in bindings 37 ], 38 _term(body), 39 ], 40 ctx=ast.Load(), 41 ), 42 slice=ast.Constant(-1), 43 ctx=ast.Load(), 44 ) 45 46 case LetRec(bindings=bindings, body=body): 47 return ast.Subscript( 48 value=ast.Tuple( 49 elts=[ 50 *[ 51 ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=ast.Constant(None)) 52 for name, _value in bindings 53 ], 54 *[ 55 ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value)) 56 for name, value in bindings 57 ], 58 _term(body), 59 ], 60 ctx=ast.Load(), 61 ), 62 slice=ast.Constant(-1), 63 ctx=ast.Load(), 64 ) 65 66 case Reference(name=name): 67 return ast.Name(id=encode(name), ctx=ast.Load()) 68 69 case Abstract(parameters=parameters, body=body): 70 return ast.Lambda( 71 args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]), 72 body=_term(body), 73 ) 74 75 case Apply(target=target, arguments=arguments): 76 return ast.Call( 77 func=_term(target), 78 args=[_term(argument) for argument in arguments], 79 ) 80 81 case Immediate(value=value): 82 return ast.Constant(value=value) 83 84 case Primitive(operator=operator, left=left, right=right): 85 match operator: 86 case "+": 87 op = ast.Add() 88 89 case "-": 90 op = ast.Sub() 91 92 case "*": # pragma: no branch 93 op = ast.Mult() 94 95 return ast.BinOp( 96 left=_term(left), 97 op=op, 98 right=_term(right), 99 ) 100 101 case Branch(operator=operator, 
left=left, right=right, consequent=consequent, otherwise=otherwise): 102 match operator: 103 case "<": 104 op = ast.Lt() 105 106 case "==": # pragma: no branch 107 op = ast.Eq() 108 109 return ast.IfExp( 110 test=ast.Compare( 111 left=_term(left), 112 ops=[op], 113 comparators=[_term(right)], 114 ), 115 body=_term(consequent), 116 orelse=_term(otherwise), 117 ) 118 119 case Allocate(count=count): 120 return ast.List( 121 elts=[ast.Constant(None) for _ in range(count)], 122 ctx=ast.Load(), 123 ) 124 125 case Load(base=base, index=index): 126 return ast.Call( 127 func=ast.Attribute(value=_term(base), attr="__getitem__", ctx=ast.Load()), 128 args=[ast.Constant(value=index)], 129 ) 130 131 case Store(base=base, index=index, value=value): 132 return ast.Subscript( 133 value=ast.Tuple( 134 elts=[ 135 ast.Call( 136 func=ast.Attribute(value=_term(base), attr="__setitem__", ctx=ast.Load()), 137 args=[ast.Constant(value=index), _term(value)], 138 ), 139 ast.Constant(value=0), 140 ], 141 ctx=ast.Load(), 142 ), 143 slice=ast.Constant(-1), 144 ctx=ast.Load(), 145 ) 146 147 case Begin(effects=effects, value=value): # pragma: no branch 148 return ast.Subscript( 149 value=ast.Tuple( 150 elts=[ 151 *[_term(effect) for effect in effects], 152 _term(value), 153 ], 154 ctx=ast.Load(), 155 ), 156 slice=ast.Constant(-1), 157 ctx=ast.Load(), 158 ) 159 160 161def to_ast_program( 162 program: Program, 163) -> str: 164 match program: 165 case Program(parameters=parameters, body=body): # pragma: no branch 166 module = ast.Module( 167 body=[ 168 ast.FunctionDef( 169 name="l3", 170 args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]), 171 body=[ 172 ast.Return(value=to_ast_term(body)), 173 ], 174 ), 175 ast.If( 176 test=ast.Compare( 177 left=ast.Name(id="__name__", ctx=ast.Load()), 178 ops=[ast.Eq()], 179 comparators=[ast.Constant(value="__main__")], 180 ), 181 body=[ 182 ast.Import(names=[ast.alias(name="sys", asname=None)]), 183 ast.Expr( 184 value=ast.Call( 
185 func=ast.Name(id="print", ctx=ast.Load()), 186 args=[ 187 ast.Call( 188 func=ast.Name(id="l3", ctx=ast.Load()), 189 args=[ 190 ast.Call( 191 func=ast.Name(id="int", ctx=ast.Load()), 192 args=[ 193 ast.Subscript( 194 value=ast.Attribute( 195 value=ast.Name(id="sys", ctx=ast.Load()), 196 attr="argv", 197 ctx=ast.Load(), 198 ), 199 slice=ast.Constant(value=i + 1), 200 ) 201 ], 202 ) 203 for i, _ in enumerate(parameters) 204 ], 205 ) 206 ], 207 ) 208 ), 209 ], 210 ), 211 ] 212 ) 213 214 ast.fix_missing_locations(module) 215 216 return ast.unparse(module)
def to_ast_term(term: Term) -> ast.expr:
25def to_ast_term( 26 term: Term, 27) -> ast.expr: 28 _term = partial(to_ast_term) 29 30 match term: 31 case Let(bindings=bindings, body=body): 32 return ast.Subscript( 33 value=ast.Tuple( 34 elts=[ 35 *[ 36 ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value)) 37 for name, value in bindings 38 ], 39 _term(body), 40 ], 41 ctx=ast.Load(), 42 ), 43 slice=ast.Constant(-1), 44 ctx=ast.Load(), 45 ) 46 47 case LetRec(bindings=bindings, body=body): 48 return ast.Subscript( 49 value=ast.Tuple( 50 elts=[ 51 *[ 52 ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=ast.Constant(None)) 53 for name, _value in bindings 54 ], 55 *[ 56 ast.NamedExpr(target=ast.Name(id=encode(name), ctx=ast.Store()), value=_term(value)) 57 for name, value in bindings 58 ], 59 _term(body), 60 ], 61 ctx=ast.Load(), 62 ), 63 slice=ast.Constant(-1), 64 ctx=ast.Load(), 65 ) 66 67 case Reference(name=name): 68 return ast.Name(id=encode(name), ctx=ast.Load()) 69 70 case Abstract(parameters=parameters, body=body): 71 return ast.Lambda( 72 args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]), 73 body=_term(body), 74 ) 75 76 case Apply(target=target, arguments=arguments): 77 return ast.Call( 78 func=_term(target), 79 args=[_term(argument) for argument in arguments], 80 ) 81 82 case Immediate(value=value): 83 return ast.Constant(value=value) 84 85 case Primitive(operator=operator, left=left, right=right): 86 match operator: 87 case "+": 88 op = ast.Add() 89 90 case "-": 91 op = ast.Sub() 92 93 case "*": # pragma: no branch 94 op = ast.Mult() 95 96 return ast.BinOp( 97 left=_term(left), 98 op=op, 99 right=_term(right), 100 ) 101 102 case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): 103 match operator: 104 case "<": 105 op = ast.Lt() 106 107 case "==": # pragma: no branch 108 op = ast.Eq() 109 110 return ast.IfExp( 111 test=ast.Compare( 112 left=_term(left), 113 ops=[op], 114 
comparators=[_term(right)], 115 ), 116 body=_term(consequent), 117 orelse=_term(otherwise), 118 ) 119 120 case Allocate(count=count): 121 return ast.List( 122 elts=[ast.Constant(None) for _ in range(count)], 123 ctx=ast.Load(), 124 ) 125 126 case Load(base=base, index=index): 127 return ast.Call( 128 func=ast.Attribute(value=_term(base), attr="__getitem__", ctx=ast.Load()), 129 args=[ast.Constant(value=index)], 130 ) 131 132 case Store(base=base, index=index, value=value): 133 return ast.Subscript( 134 value=ast.Tuple( 135 elts=[ 136 ast.Call( 137 func=ast.Attribute(value=_term(base), attr="__setitem__", ctx=ast.Load()), 138 args=[ast.Constant(value=index), _term(value)], 139 ), 140 ast.Constant(value=0), 141 ], 142 ctx=ast.Load(), 143 ), 144 slice=ast.Constant(-1), 145 ctx=ast.Load(), 146 ) 147 148 case Begin(effects=effects, value=value): # pragma: no branch 149 return ast.Subscript( 150 value=ast.Tuple( 151 elts=[ 152 *[_term(effect) for effect in effects], 153 _term(value), 154 ], 155 ctx=ast.Load(), 156 ), 157 slice=ast.Constant(-1), 158 ctx=ast.Load(), 159 )
162def to_ast_program( 163 program: Program, 164) -> str: 165 match program: 166 case Program(parameters=parameters, body=body): # pragma: no branch 167 module = ast.Module( 168 body=[ 169 ast.FunctionDef( 170 name="l3", 171 args=ast.arguments(args=[ast.arg(arg=encode(parameter)) for parameter in parameters]), 172 body=[ 173 ast.Return(value=to_ast_term(body)), 174 ], 175 ), 176 ast.If( 177 test=ast.Compare( 178 left=ast.Name(id="__name__", ctx=ast.Load()), 179 ops=[ast.Eq()], 180 comparators=[ast.Constant(value="__main__")], 181 ), 182 body=[ 183 ast.Import(names=[ast.alias(name="sys", asname=None)]), 184 ast.Expr( 185 value=ast.Call( 186 func=ast.Name(id="print", ctx=ast.Load()), 187 args=[ 188 ast.Call( 189 func=ast.Name(id="l3", ctx=ast.Load()), 190 args=[ 191 ast.Call( 192 func=ast.Name(id="int", ctx=ast.Load()), 193 args=[ 194 ast.Subscript( 195 value=ast.Attribute( 196 value=ast.Name(id="sys", ctx=ast.Load()), 197 attr="argv", 198 ctx=ast.Load(), 199 ), 200 slice=ast.Constant(value=i + 1), 201 ) 202 ], 203 ) 204 for i, _ in enumerate(parameters) 205 ], 206 ) 207 ], 208 ) 209 ), 210 ], 211 ), 212 ] 213 ) 214 215 ast.fix_missing_locations(module) 216 217 return ast.unparse(module)