L1.to_python
1import ast 2from functools import partial 3 4from util.encode import encode 5 6from .syntax import ( 7 Abstract, 8 Allocate, 9 Apply, 10 Branch, 11 Copy, 12 Halt, 13 Immediate, 14 Load, 15 Primitive, 16 Program, 17 Statement, 18 Store, 19) 20 21 22def load(name: str) -> ast.Name: 23 return ast.Name(id=encode(name), ctx=ast.Load()) 24 25 26def store(name: str) -> ast.Name: 27 return ast.Name(id=encode(name), ctx=ast.Store()) 28 29 30def to_ast_statement( 31 statement: Statement, 32) -> list[ast.stmt]: 33 _statement = partial(to_ast_statement) 34 35 match statement: 36 case Copy(destination=destination, source=source, then=then): 37 return [ 38 ast.Assign(targets=[store(destination)], value=load(source)), 39 *_statement(then), 40 ] 41 42 case Abstract(destination=destination, parameters=parameters, body=body, then=then): 43 return [ 44 ast.FunctionDef( 45 name=encode(destination), 46 args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]), 47 body=_statement(body), 48 ), 49 *_statement(then), 50 ] 51 52 case Apply(target=target, arguments=arguments): 53 return [ 54 ast.Return( 55 ast.Call( 56 func=load(target), 57 args=[load(argument) for argument in arguments], 58 ) 59 ) 60 ] 61 62 case Immediate(destination=destination, value=value, then=then): 63 return [ 64 ast.Assign(targets=[store(destination)], value=ast.Constant(value=value)), 65 *_statement(then), 66 ] 67 68 case Primitive(destination=destination, operator=operator, left=left, right=right, then=then): 69 match operator: 70 case "+": 71 op = ast.Add() 72 73 case "-": 74 op = ast.Sub() 75 76 case "*": # pragma: no branch 77 op = ast.Mult() 78 79 return [ 80 ast.Assign( 81 targets=[store(destination)], 82 value=ast.BinOp(left=load(left), op=op, right=load(right)), 83 ), 84 *_statement(then), 85 ] 86 87 case Branch(operator=operator, left=left, right=right, then=then, otherwise=otherwise): 88 match operator: 89 case "<": 90 op = ast.Lt() 91 92 case "==": # pragma: no branch 93 op = ast.Eq() 94 95 return [ 96 ast.If( 97 ast.Compare(left=load(left), ops=[op], comparators=[load(right)]), 98 body=_statement(then), 99 orelse=_statement(otherwise), 100 ), 101 ] 102 103 case Allocate(destination=destination, count=count, then=then): 104 return [ 105 ast.Assign( 106 targets=[store(destination)], 107 value=ast.List( 108 elts=[ast.Constant(None) for _ in range(count)], 109 ctx=ast.Load(), 110 ), 111 ), 112 *_statement(then), 113 ] 114 115 case Load(destination=destination, base=base, index=index, then=then): 116 return [ 117 ast.Assign( 118 targets=[store(destination)], 119 value=ast.Subscript( 120 value=load(base), 121 slice=ast.Constant(index), 122 ctx=ast.Load(), 123 ), 124 ), 125 *_statement(then), 126 ] 127 128 case Store(base=base, index=index, value=value, then=then): 129 return [ 130 ast.Assign( 131 targets=[ 132 ast.Subscript( 133 value=load(base), 134 slice=ast.Constant(index), 135 ctx=ast.Store(), 136 ) 137 ], 138 value=load(value), 139 ), 140 *_statement(then), 141 ] 142 143 case Halt(value=value): # pragma: no branch 144 return [ 145 ast.Return(value=load(value)), 146 ] 147 148 149def to_ast_program( 150 program: Program, 151) -> str: 152 match program: 153 case Program(parameters=parameters, body=body): # pragma: no branch 154 module = ast.Module( 155 body=[ 156 ast.FunctionDef( 157 name="l1", 158 args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]), 159 body=to_ast_statement(body), 160 ), 161 ast.If( 162 test=ast.Compare( 163 left=ast.Name(id="__name__", ctx=ast.Load()), 164 ops=[ast.Eq()], 165 comparators=[ast.Constant(value="__main__")], 166 ), 167 body=[ 168 ast.Import(names=[ast.alias(name="sys", asname=None)]), 169 ast.Expr( 170 value=ast.Call( 171 func=ast.Name(id="print", ctx=ast.Load()), 172 args=[ 173 ast.Call( 174 func=ast.Name(id="l1", ctx=ast.Load()), 175 args=[ 176 ast.Call( 177 func=ast.Name(id="int", ctx=ast.Load()), 178 args=[ 179 ast.Subscript( 180 value=ast.Attribute( 181 value=ast.Name(id="sys", ctx=ast.Load()), 182 attr="argv", 183 ctx=ast.Load(), 184 ), 185 slice=ast.Constant(value=i + 1), 186 ) 187 ], 188 ) 189 for i, _ in enumerate(parameters) 190 ], 191 ) 192 ], 193 ) 194 ), 195 ], 196 ), 197 ] 198 ) 199 200 ast.fix_missing_locations(module) 201 202 return ast.unparse(module)
def
load(name: str) -> ast.Name:
def
store(name: str) -> ast.Name:
def
to_ast_statement(statement: Statement) -> list[ast.stmt]:
31def to_ast_statement( 32 statement: Statement, 33) -> list[ast.stmt]: 34 _statement = partial(to_ast_statement) 35 36 match statement: 37 case Copy(destination=destination, source=source, then=then): 38 return [ 39 ast.Assign(targets=[store(destination)], value=load(source)), 40 *_statement(then), 41 ] 42 43 case Abstract(destination=destination, parameters=parameters, body=body, then=then): 44 return [ 45 ast.FunctionDef( 46 name=encode(destination), 47 args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]), 48 body=_statement(body), 49 ), 50 *_statement(then), 51 ] 52 53 case Apply(target=target, arguments=arguments): 54 return [ 55 ast.Return( 56 ast.Call( 57 func=load(target), 58 args=[load(argument) for argument in arguments], 59 ) 60 ) 61 ] 62 63 case Immediate(destination=destination, value=value, then=then): 64 return [ 65 ast.Assign(targets=[store(destination)], value=ast.Constant(value=value)), 66 *_statement(then), 67 ] 68 69 case Primitive(destination=destination, operator=operator, left=left, right=right, then=then): 70 match operator: 71 case "+": 72 op = ast.Add() 73 74 case "-": 75 op = ast.Sub() 76 77 case "*": # pragma: no branch 78 op = ast.Mult() 79 80 return [ 81 ast.Assign( 82 targets=[store(destination)], 83 value=ast.BinOp(left=load(left), op=op, right=load(right)), 84 ), 85 *_statement(then), 86 ] 87 88 case Branch(operator=operator, left=left, right=right, then=then, otherwise=otherwise): 89 match operator: 90 case "<": 91 op = ast.Lt() 92 93 case "==": # pragma: no branch 94 op = ast.Eq() 95 96 return [ 97 ast.If( 98 ast.Compare(left=load(left), ops=[op], comparators=[load(right)]), 99 body=_statement(then), 100 orelse=_statement(otherwise), 101 ), 102 ] 103 104 case Allocate(destination=destination, count=count, then=then): 105 return [ 106 ast.Assign( 107 targets=[store(destination)], 108 value=ast.List( 109 elts=[ast.Constant(None) for _ in range(count)], 110 ctx=ast.Load(), 111 ), 112 ), 113 *_statement(then), 114 ] 115 116 case Load(destination=destination, base=base, index=index, then=then): 117 return [ 118 ast.Assign( 119 targets=[store(destination)], 120 value=ast.Subscript( 121 value=load(base), 122 slice=ast.Constant(index), 123 ctx=ast.Load(), 124 ), 125 ), 126 *_statement(then), 127 ] 128 129 case Store(base=base, index=index, value=value, then=then): 130 return [ 131 ast.Assign( 132 targets=[ 133 ast.Subscript( 134 value=load(base), 135 slice=ast.Constant(index), 136 ctx=ast.Store(), 137 ) 138 ], 139 value=load(value), 140 ), 141 *_statement(then), 142 ] 143 144 case Halt(value=value): # pragma: no branch 145 return [ 146 ast.Return(value=load(value)), 147 ]
150def to_ast_program( 151 program: Program, 152) -> str: 153 match program: 154 case Program(parameters=parameters, body=body): # pragma: no branch 155 module = ast.Module( 156 body=[ 157 ast.FunctionDef( 158 name="l1", 159 args=ast.arguments(args=[ast.arg(arg=parameter) for parameter in parameters]), 160 body=to_ast_statement(body), 161 ), 162 ast.If( 163 test=ast.Compare( 164 left=ast.Name(id="__name__", ctx=ast.Load()), 165 ops=[ast.Eq()], 166 comparators=[ast.Constant(value="__main__")], 167 ), 168 body=[ 169 ast.Import(names=[ast.alias(name="sys", asname=None)]), 170 ast.Expr( 171 value=ast.Call( 172 func=ast.Name(id="print", ctx=ast.Load()), 173 args=[ 174 ast.Call( 175 func=ast.Name(id="l1", ctx=ast.Load()), 176 args=[ 177 ast.Call( 178 func=ast.Name(id="int", ctx=ast.Load()), 179 args=[ 180 ast.Subscript( 181 value=ast.Attribute( 182 value=ast.Name(id="sys", ctx=ast.Load()), 183 attr="argv", 184 ctx=ast.Load(), 185 ), 186 slice=ast.Constant(value=i + 1), 187 ) 188 ], 189 ) 190 for i, _ in enumerate(parameters) 191 ], 192 ) 193 ], 194 ) 195 ), 196 ], 197 ), 198 ] 199 ) 200 201 ast.fix_missing_locations(module) 202 203 return ast.unparse(module)