Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
5 changed files with 103 additions and 44 deletions
Showing only changes of commit 7669f3cbca - Show all commits

View File

@ -368,7 +368,7 @@ class Function:
"""
A function processes input and produces output
"""
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', )
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', )
name: str
lineno: int
@ -376,6 +376,7 @@ class Function:
imported: bool
statements: List[Statement]
returns: TypeBase
returns_type_var: Optional[TypeVar]
posonlyargs: List[FunctionParam]
def __init__(self, name: str, lineno: int) -> None:
@ -385,6 +386,7 @@ class Function:
self.imported = False
self.statements = []
self.returns = TypeNone()
self.returns_type_var = None
self.posonlyargs = []
class StructConstructor(Function):

View File

@ -564,8 +564,8 @@ class OurVisitor:
func = module.functions[node.func.id]
if func.returns != exp_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
# if func.returns != exp_type:
# _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
@ -700,47 +700,47 @@ class OurVisitor:
_not_implemented(node.kind is None, 'Constant.kind')
if isinstance(exp_type, TypeUInt8):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 255:
_raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < 0 or node.value > 255:
# _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
return ConstantUInt8(exp_type, node.value)
if isinstance(exp_type, TypeUInt32):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 4294967295:
_raise_static_error(node, 'Integer value out of range')
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < 0 or node.value > 4294967295:
# _raise_static_error(node, 'Integer value out of range')
return ConstantUInt32(exp_type, node.value)
if isinstance(exp_type, TypeUInt64):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 18446744073709551615:
_raise_static_error(node, 'Integer value out of range')
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < 0 or node.value > 18446744073709551615:
# _raise_static_error(node, 'Integer value out of range')
return ConstantUInt64(exp_type, node.value)
if isinstance(exp_type, TypeInt32):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < -2147483648 or node.value > 2147483647:
_raise_static_error(node, 'Integer value out of range')
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < -2147483648 or node.value > 2147483647:
# _raise_static_error(node, 'Integer value out of range')
return ConstantInt32(exp_type, node.value)
if isinstance(exp_type, TypeInt64):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < -9223372036854775808 or node.value > 9223372036854775807:
_raise_static_error(node, 'Integer value out of range')
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < -9223372036854775808 or node.value > 9223372036854775807:
# _raise_static_error(node, 'Integer value out of range')
return ConstantInt64(exp_type, node.value)

View File

@ -9,22 +9,23 @@ def phasm_type(inp: ourlang.Module) -> None:
module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
if getattr(inp, 'value', int):
value = getattr(inp, 'value', None)
if isinstance(value, int):
result = ctx.new_var()
# Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # type: ignore
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2))
# Don't dictate anything about signedness - you can use a signed
# constant in an unsigned variable if the bits fit
result.add_constraint(TypeConstraintSigned(None))
result.add_location(str(inp.value)) # type: ignore
result.add_location(str(value))
inp.type_var = result
return result
raise NotImplementedError(constant, inp)
raise NotImplementedError(constant, inp, value)
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant):
@ -43,58 +44,68 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
ctx.unify(left, right)
return left
if isinstance(inp, ourlang.FunctionCall):
assert inp.function.returns_type_var is not None
if inp.function.posonlyargs:
raise NotImplementedError
return inp.function.returns_type_var
raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None:
for param in inp.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type)
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement')
typ = expression(ctx, inp.statements[0].value)
ctx.unify(_convert_old_type(ctx, inp.returns), typ)
assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ)
return
def module(inp: ourlang.Module) -> None:
ctx = Context()
for func in inp.functions.values():
func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)')
for param in func.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
for func in inp.functions.values():
function(ctx, func)
from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar:
def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar:
result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u8')
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u32')
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u64')
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i32')
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i64')
result.add_location(location)
return result
raise NotImplementedError(_convert_old_type, inp)

View File

@ -304,6 +304,21 @@ def testEntry(a: i32, b: i32) -> i32:
assert 1 == suite.run_code(10, 20).returned_value
assert 0 == suite.run_code(10, 10).returned_value
@pytest.mark.integration_test
def test_call_no_args():
code_py = """
def helper() -> i32:
return 19
@exported
def testEntry() -> i32:
return helper()
"""
result = Suite(code_py).run_code()
assert 19 == result.returned_value
@pytest.mark.integration_test
def test_call_pre_defined():
code_py = """

View File

@ -0,0 +1,31 @@
import pytest
from phasm.parser import phasm_parse
from phasm.typer import phasm_type
from phasm.exceptions import TypingError
@pytest.mark.integration_test
def test_constant_too_wide():
code_py = """
def func_const() -> u8:
return 0xFFF
"""
ast = phasm_parse(code_py)
with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'):
phasm_type(ast)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', [32, 64])
def test_signed_mismatch(type_):
code_py = f"""
def func_const() -> u{type_}:
return 0
def func_call() -> i{type_}:
return func_const()
"""
ast = phasm_parse(code_py)
with pytest.raises(TypingError, match='Signed does not matchq'):
phasm_type(ast)