L3.check

  1from collections.abc import Mapping
  2from functools import partial
  3from typing import Counter
  4
  5from .syntax import (
  6    Abstract,
  7    Allocate,
  8    Apply,
  9    Begin,
 10    Branch,
 11    Identifier,
 12    Immediate,
 13    Let,
 14    LetRec,
 15    Load,
 16    Primitive,
 17    Program,
 18    Reference,
 19    Store,
 20    Term,
 21)
 22
# Scope used by the checker: the identifiers currently bound.
# Only key membership is consulted (values are always None), so the mapping
# acts as a set that can be layered with ``{**outer, **inner}``.
type Context = Mapping[Identifier, None]
 24
 25
 26def check_term(
 27    term: Term,
 28    context: Context,
 29) -> None:
 30    recur = partial(check_term, context=context)  # noqa: F841
 31
 32    match term:
 33        case Let(bindings=bindings, body=body):
 34            counts = Counter(name for name, _ in bindings)
 35            duplicates = {name: count for name, count in counts.items() if count > 1}
 36            if duplicates:
 37                raise ValueError(f"Duplicate bindings: {duplicates}")
 38
 39            for _, value in bindings:
 40                recur(value)
 41
 42            local = dict.fromkeys([name for name, _ in bindings])
 43            recur(body, context={**context, **local})
 44
 45        case LetRec(bindings=bindings, body=body):
 46            counts = Counter(name for name, _ in bindings)
 47            duplicates = {name: count for name, count in counts.items() if count > 1}
 48            if duplicates:
 49                raise ValueError(f"Duplicate bindings: {duplicates}")
 50
 51            local = dict.fromkeys([name for name, _ in bindings])
 52
 53            for name, value in bindings:
 54                recur(value, context={**context, **local})
 55
 56            check_term(body, context={**context, **local})
 57
 58        case Reference(name=name):  # Leaf
 59            if name not in context:
 60                raise ValueError(f"Unbound variable: {name}")
 61
 62        case Abstract(parameters=parameters, body=body):  # Done
 63            counts = Counter(parameters)
 64            duplicates = {name for name, count in counts.items() if count > 1}
 65            if duplicates:
 66                raise ValueError(f"Duplicate parameters: {duplicates}")
 67            local = dict.fromkeys(parameters, None)
 68            check_term(body, context=local)
 69
 70        case Apply(target=target, arguments=arguments):  # Done
 71            recur(target)
 72            for argument in arguments:
 73                recur(argument)
 74
 75        case Immediate(value=_value):  # Leaf
 76            pass
 77
 78        case Primitive(operator=_operator, left=left, right=right):  # Should be done
 79            recur(left)
 80            recur(right)
 81
 82        case Branch(operator=_operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
 83            recur(left)
 84            recur(right)
 85            recur(consequent)
 86            recur(otherwise)
 87
 88        case Allocate(count=_count):
 89            pass  # No need to check count, as it is a non-negative integer by construction
 90
 91        case Load(base=base, index=_index):
 92            recur(base)  # No need to check index, as it is a non-negative integer by construction
 93
 94        case Store(base=base, index=_index, value=value):
 95            recur(base)  # No need to check index, as it is a non-negative integer by construction
 96            recur(value)
 97
 98        case Begin(effects=effects, value=value):  # pragma: no branch
 99            for effect in effects:
100                recur(effect)
101            recur(value)
102
103
104def check_program(program: Program) -> None:
105    match program:
106        case Program(parameters=parameters, body=body):  # pragma: no branch
107            counts = Counter(parameters)
108            duplicates = {name for name, count in counts.items() if count > 1}
109            if duplicates:
110                raise ValueError(f"Duplicate parameters: {duplicates}")
111            local = dict.fromkeys(parameters, None)
112            check_term(body, context=local)
# Scope used by the checker: the identifiers currently bound.
# Only key membership is consulted (values are always None), so the mapping
# acts as a set that can be layered with ``{**outer, **inner}``.
type Context = Mapping[Identifier, None]
def check_term(term: Term, context: Context) -> None:
 27def check_term(
 28    term: Term,
 29    context: Context,
 30) -> None:
 31    recur = partial(check_term, context=context)  # noqa: F841
 32
 33    match term:
 34        case Let(bindings=bindings, body=body):
 35            counts = Counter(name for name, _ in bindings)
 36            duplicates = {name: count for name, count in counts.items() if count > 1}
 37            if duplicates:
 38                raise ValueError(f"Duplicate bindings: {duplicates}")
 39
 40            for _, value in bindings:
 41                recur(value)
 42
 43            local = dict.fromkeys([name for name, _ in bindings])
 44            recur(body, context={**context, **local})
 45
 46        case LetRec(bindings=bindings, body=body):
 47            counts = Counter(name for name, _ in bindings)
 48            duplicates = {name: count for name, count in counts.items() if count > 1}
 49            if duplicates:
 50                raise ValueError(f"Duplicate bindings: {duplicates}")
 51
 52            local = dict.fromkeys([name for name, _ in bindings])
 53
 54            for name, value in bindings:
 55                recur(value, context={**context, **local})
 56
 57            check_term(body, context={**context, **local})
 58
 59        case Reference(name=name):  # Leaf
 60            if name not in context:
 61                raise ValueError(f"Unbound variable: {name}")
 62
 63        case Abstract(parameters=parameters, body=body):  # Done
 64            counts = Counter(parameters)
 65            duplicates = {name for name, count in counts.items() if count > 1}
 66            if duplicates:
 67                raise ValueError(f"Duplicate parameters: {duplicates}")
 68            local = dict.fromkeys(parameters, None)
 69            check_term(body, context=local)
 70
 71        case Apply(target=target, arguments=arguments):  # Done
 72            recur(target)
 73            for argument in arguments:
 74                recur(argument)
 75
 76        case Immediate(value=_value):  # Leaf
 77            pass
 78
 79        case Primitive(operator=_operator, left=left, right=right):  # Should be done
 80            recur(left)
 81            recur(right)
 82
 83        case Branch(operator=_operator, left=left, right=right, consequent=consequent, otherwise=otherwise):
 84            recur(left)
 85            recur(right)
 86            recur(consequent)
 87            recur(otherwise)
 88
 89        case Allocate(count=_count):
 90            pass  # No need to check count, as it is a non-negative integer by construction
 91
 92        case Load(base=base, index=_index):
 93            recur(base)  # No need to check index, as it is a non-negative integer by construction
 94
 95        case Store(base=base, index=_index, value=value):
 96            recur(base)  # No need to check index, as it is a non-negative integer by construction
 97            recur(value)
 98
 99        case Begin(effects=effects, value=value):  # pragma: no branch
100            for effect in effects:
101                recur(effect)
102            recur(value)
def check_program(program: L3.syntax.Program) -> None:
105def check_program(program: Program) -> None:
106    match program:
107        case Program(parameters=parameters, body=body):  # pragma: no branch
108            counts = Counter(parameters)
109            duplicates = {name for name, count in counts.items() if count > 1}
110            if duplicates:
111                raise ValueError(f"Duplicate parameters: {duplicates}")
112            local = dict.fromkeys(parameters, None)
113            check_term(body, context=local)