phasm/phasm/typer.py
2022-09-16 16:43:40 +02:00

101 lines
3.3 KiB
Python

"""
Type-checks and enriches the given AST
"""
from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar
def phasm_type(inp: ourlang.Module) -> None:
    """
    Public entry point: type-check a whole module.

    All the work is delegated to ``module``, which builds a fresh typing
    context and types every function in ``inp``.

    :param inp: The parsed module to type-check.
    """
    module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
    """
    Create a type variable for a constant expression.

    Only integer constants are supported. The new type variable requires
    enough bits to hold the value, leaves signedness open, is labelled with
    the value's text, and is stored on the constant as ``inp.type_var``.

    :param ctx: The typing context used to allocate the type variable.
    :param inp: The constant to type.
    :returns: The newly created type variable.
    :raises NotImplementedError: If the constant does not carry an ``int``
        ``value``.
    """
    value = getattr(inp, 'value', None)

    # Check the type explicitly. A truthiness test would reject the literal 0,
    # and would let a constant without a ``value`` attribute slip through.
    if not isinstance(value, int):
        raise NotImplementedError(constant, inp)

    result = ctx.new_var()

    # Need at least this many bits to store this constant value.
    # len(bin(...)) - 2 strips the '0b' prefix. For negative values the '-'
    # sign is also counted, which adds one bit.
    result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2))

    # Don't dictate anything about signedness - you can use a signed
    # constant in an unsigned variable if the bits fit
    result.add_constraint(TypeConstraintSigned(None))

    result.add_location(str(value))

    inp.type_var = result
    return result
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
    """
    Determine the type variable for an expression.

    Supported forms:

    * constants - delegated to ``constant``;
    * variable references - reuse the type variable already attached to the
      referenced variable (it must have been assigned beforehand);
    * binary operators ``+ - | & ^`` - both operands are typed recursively
      and unified, and the left operand's type variable is returned.

    :raises NotImplementedError: For any other expression kind or operator.
    """
    if isinstance(inp, ourlang.Constant):
        return constant(ctx, inp)

    if isinstance(inp, ourlang.VariableReference):
        referenced_type = inp.variable.type_var
        assert referenced_type is not None, inp
        return referenced_type

    if not isinstance(inp, ourlang.BinaryOp):
        raise NotImplementedError(expression, inp)

    supported_operators = ('+', '-', '|', '&', '^')
    if inp.operator not in supported_operators:
        raise NotImplementedError(expression, inp, inp.operator)

    # Both operands must end up with the same type
    lhs_type = expression(ctx, inp.left)
    rhs_type = expression(ctx, inp.right)
    ctx.unify(lhs_type, rhs_type)
    return lhs_type
def function(ctx: 'Context', inp: ourlang.Function) -> None:
    """
    Type-check a single function definition.

    Every positional-only parameter receives a type variable derived from its
    declared type. Only bodies consisting of exactly one ``return`` statement
    are supported. The returned expression's type variable is unified with
    the declared return type.

    :raises NotImplementedError: If the body is anything other than a single
        return statement.
    """
    for arg in inp.posonlyargs:
        arg.type_var = _convert_old_type(ctx, arg.type)

    body = inp.statements
    is_single_return = len(body) == 1 and isinstance(body[0], ourlang.StatementReturn)
    if not is_single_return:
        raise NotImplementedError('Functions with not just a return statement')

    returned_type = expression(ctx, body[0].value)
    declared_type = _convert_old_type(ctx, inp.returns)
    ctx.unify(declared_type, returned_type)
def module(inp: ourlang.Module) -> None:
    """
    Type-check every function in a module.

    A single ``Context`` is created and shared by all functions, so their
    type variables are all allocated from and unified within that one
    context.
    """
    shared_ctx = Context()
    for func_def in inp.functions.values():
        function(shared_ctx, func_def)
from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar:
    """
    Translate a declared (old-style) type into a constrained type variable.

    The new type variable is pinned to the exact bit width and signedness of
    the declared type and labelled with the type's short name.

    :param ctx: The typing context used to allocate the type variable.
    :param inp: The declared type.
    :returns: The newly created type variable.
    :raises NotImplementedError: If ``inp`` is not one of the supported
        integer types.
    """
    # (declared type class, bit width, signed, location label), checked in order
    known_types = (
        (typing.TypeUInt8, 8, False, 'u8'),
        (typing.TypeUInt32, 32, False, 'u32'),
        (typing.TypeUInt64, 64, False, 'u64'),
        (typing.TypeInt32, 32, True, 'i32'),
        (typing.TypeInt64, 64, True, 'i64'),
    )

    for type_class, bits, signed, label in known_types:
        if isinstance(inp, type_class):
            break
    else:
        # Raise before allocating, so no orphan type variable is left in ctx
        raise NotImplementedError(_convert_old_type, inp)

    result = ctx.new_var()
    # Every declared type gets an exact width (minb == maxb), like u8/u32/i32.
    # NOTE(review): a 32..64 range for u64/i64 would presumably let a 32-bit
    # value unify with a 64-bit declaration (assuming unify intersects the
    # ranges) - confirm against TypeConstraintBitWidth.
    result.add_constraint(TypeConstraintBitWidth(minb=bits, maxb=bits))
    result.add_constraint(TypeConstraintSigned(signed))
    result.add_location(label)
    return result