Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
5 changed files with 103 additions and 44 deletions
Showing only changes of commit 7669f3cbca - Show all commits

View File

@ -368,7 +368,7 @@ class Function:
""" """
A function processes input and produces output A function processes input and produces output
""" """
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', )
name: str name: str
lineno: int lineno: int
@ -376,6 +376,7 @@ class Function:
imported: bool imported: bool
statements: List[Statement] statements: List[Statement]
returns: TypeBase returns: TypeBase
returns_type_var: Optional[TypeVar]
posonlyargs: List[FunctionParam] posonlyargs: List[FunctionParam]
def __init__(self, name: str, lineno: int) -> None: def __init__(self, name: str, lineno: int) -> None:
@ -385,6 +386,7 @@ class Function:
self.imported = False self.imported = False
self.statements = [] self.statements = []
self.returns = TypeNone() self.returns = TypeNone()
self.returns_type_var = None
self.posonlyargs = [] self.posonlyargs = []
class StructConstructor(Function): class StructConstructor(Function):

View File

@ -564,8 +564,8 @@ class OurVisitor:
func = module.functions[node.func.id] func = module.functions[node.func.id]
if func.returns != exp_type: # if func.returns != exp_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') # _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if len(func.posonlyargs) != len(node.args): if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
@ -700,47 +700,47 @@ class OurVisitor:
_not_implemented(node.kind is None, 'Constant.kind') _not_implemented(node.kind is None, 'Constant.kind')
if isinstance(exp_type, TypeUInt8): if isinstance(exp_type, TypeUInt8):
if not isinstance(node.value, int): # if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value') # _raise_static_error(node, 'Expected integer value')
#
if node.value < 0 or node.value > 255: # if node.value < 0 or node.value > 255:
_raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') # _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
return ConstantUInt8(exp_type, node.value) return ConstantUInt8(exp_type, node.value)
if isinstance(exp_type, TypeUInt32): if isinstance(exp_type, TypeUInt32):
if not isinstance(node.value, int): # if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value') # _raise_static_error(node, 'Expected integer value')
#
if node.value < 0 or node.value > 4294967295: # if node.value < 0 or node.value > 4294967295:
_raise_static_error(node, 'Integer value out of range') # _raise_static_error(node, 'Integer value out of range')
return ConstantUInt32(exp_type, node.value) return ConstantUInt32(exp_type, node.value)
if isinstance(exp_type, TypeUInt64): if isinstance(exp_type, TypeUInt64):
if not isinstance(node.value, int): # if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value') # _raise_static_error(node, 'Expected integer value')
#
if node.value < 0 or node.value > 18446744073709551615: # if node.value < 0 or node.value > 18446744073709551615:
_raise_static_error(node, 'Integer value out of range') # _raise_static_error(node, 'Integer value out of range')
return ConstantUInt64(exp_type, node.value) return ConstantUInt64(exp_type, node.value)
if isinstance(exp_type, TypeInt32): if isinstance(exp_type, TypeInt32):
if not isinstance(node.value, int): # if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value') # _raise_static_error(node, 'Expected integer value')
#
if node.value < -2147483648 or node.value > 2147483647: # if node.value < -2147483648 or node.value > 2147483647:
_raise_static_error(node, 'Integer value out of range') # _raise_static_error(node, 'Integer value out of range')
return ConstantInt32(exp_type, node.value) return ConstantInt32(exp_type, node.value)
if isinstance(exp_type, TypeInt64): if isinstance(exp_type, TypeInt64):
if not isinstance(node.value, int): # if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value') # _raise_static_error(node, 'Expected integer value')
#
if node.value < -9223372036854775808 or node.value > 9223372036854775807: # if node.value < -9223372036854775808 or node.value > 9223372036854775807:
_raise_static_error(node, 'Integer value out of range') # _raise_static_error(node, 'Integer value out of range')
return ConstantInt64(exp_type, node.value) return ConstantInt64(exp_type, node.value)

View File

@ -9,22 +9,23 @@ def phasm_type(inp: ourlang.Module) -> None:
module(inp) module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
if getattr(inp, 'value', int): value = getattr(inp, 'value', None)
if isinstance(value, int):
result = ctx.new_var() result = ctx.new_var()
# Need at least this many bits to store this constant value # Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # type: ignore result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2))
# Don't dictate anything about signedness - you can use a signed # Don't dictate anything about signedness - you can use a signed
# constant in an unsigned variable if the bits fit # constant in an unsigned variable if the bits fit
result.add_constraint(TypeConstraintSigned(None)) result.add_constraint(TypeConstraintSigned(None))
result.add_location(str(inp.value)) # type: ignore result.add_location(str(value))
inp.type_var = result inp.type_var = result
return result return result
raise NotImplementedError(constant, inp) raise NotImplementedError(constant, inp, value)
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
@ -43,58 +44,68 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
ctx.unify(left, right) ctx.unify(left, right)
return left return left
if isinstance(inp, ourlang.FunctionCall):
assert inp.function.returns_type_var is not None
if inp.function.posonlyargs:
raise NotImplementedError
return inp.function.returns_type_var
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None: def function(ctx: 'Context', inp: ourlang.Function) -> None:
for param in inp.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type)
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement') raise NotImplementedError('Functions with not just a return statement')
typ = expression(ctx, inp.statements[0].value) typ = expression(ctx, inp.statements[0].value)
ctx.unify(_convert_old_type(ctx, inp.returns), typ) assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ)
return return
def module(inp: ourlang.Module) -> None: def module(inp: ourlang.Module) -> None:
ctx = Context() ctx = Context()
for func in inp.functions.values():
func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)')
for param in func.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
for func in inp.functions.values(): for func in inp.functions.values():
function(ctx, func) function(ctx, func)
from . import typing from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar:
result = ctx.new_var() result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8): if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u8') result.add_location(location)
return result return result
if isinstance(inp, typing.TypeUInt32): if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u32') result.add_location(location)
return result return result
if isinstance(inp, typing.TypeUInt64): if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u64') result.add_location(location)
return result return result
if isinstance(inp, typing.TypeInt32): if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True)) result.add_constraint(TypeConstraintSigned(True))
result.add_location('i32') result.add_location(location)
return result return result
if isinstance(inp, typing.TypeInt64): if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True)) result.add_constraint(TypeConstraintSigned(True))
result.add_location('i64') result.add_location(location)
return result return result
raise NotImplementedError(_convert_old_type, inp) raise NotImplementedError(_convert_old_type, inp)

View File

@ -304,6 +304,21 @@ def testEntry(a: i32, b: i32) -> i32:
assert 1 == suite.run_code(10, 20).returned_value assert 1 == suite.run_code(10, 20).returned_value
assert 0 == suite.run_code(10, 10).returned_value assert 0 == suite.run_code(10, 10).returned_value
@pytest.mark.integration_test
def test_call_no_args():
code_py = """
def helper() -> i32:
return 19
@exported
def testEntry() -> i32:
return helper()
"""
result = Suite(code_py).run_code()
assert 19 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
def test_call_pre_defined(): def test_call_pre_defined():
code_py = """ code_py = """

View File

@ -0,0 +1,31 @@
import pytest
from phasm.parser import phasm_parse
from phasm.typer import phasm_type
from phasm.exceptions import TypingError
@pytest.mark.integration_test
def test_constant_too_wide():
code_py = """
def func_const() -> u8:
return 0xFFF
"""
ast = phasm_parse(code_py)
with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'):
phasm_type(ast)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', [32, 64])
def test_signed_mismatch(type_):
code_py = f"""
def func_const() -> u{type_}:
return 0
def func_call() -> i{type_}:
return func_const()
"""
ast = phasm_parse(code_py)
with pytest.raises(TypingError, match='Signed does not matchq'):
phasm_type(ast)