diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 6ae9518..efc1a66 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -368,7 +368,7 @@ class Function: """ A function processes input and produces output """ - __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) + __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', ) name: str lineno: int @@ -376,6 +376,7 @@ class Function: imported: bool statements: List[Statement] returns: TypeBase + returns_type_var: Optional[TypeVar] posonlyargs: List[FunctionParam] def __init__(self, name: str, lineno: int) -> None: @@ -385,6 +386,7 @@ class Function: self.imported = False self.statements = [] self.returns = TypeNone() + self.returns_type_var = None self.posonlyargs = [] class StructConstructor(Function): diff --git a/phasm/parser.py b/phasm/parser.py index 2d7a158..78486ef 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -564,8 +564,8 @@ class OurVisitor: func = module.functions[node.func.id] - if func.returns != exp_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') + # if func.returns != exp_type: + # _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') if len(func.posonlyargs) != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') @@ -700,47 +700,47 @@ class OurVisitor: _not_implemented(node.kind is None, 'Constant.kind') if isinstance(exp_type, TypeUInt8): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 255: - _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < 0 or node.value > 255: + # _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') return ConstantUInt8(exp_type, node.value) if isinstance(exp_type, TypeUInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 4294967295: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < 0 or node.value > 4294967295: + # _raise_static_error(node, 'Integer value out of range') return ConstantUInt32(exp_type, node.value) if isinstance(exp_type, TypeUInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 18446744073709551615: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < 0 or node.value > 18446744073709551615: + # _raise_static_error(node, 'Integer value out of range') return ConstantUInt64(exp_type, node.value) if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -2147483648 or node.value > 2147483647: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < -2147483648 or node.value > 2147483647: + # _raise_static_error(node, 'Integer value out of range') return ConstantInt32(exp_type, node.value) if isinstance(exp_type, TypeInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -9223372036854775808 or node.value > 9223372036854775807: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < -9223372036854775808 or node.value > 9223372036854775807: + # _raise_static_error(node, 'Integer value out of range') return ConstantInt64(exp_type, node.value) diff --git a/phasm/typer.py b/phasm/typer.py index c0677fb..97c5e44 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -9,22 +9,23 @@ def phasm_type(inp: ourlang.Module) -> None: module(inp) def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': - if getattr(inp, 'value', int): + value = getattr(inp, 'value', None) + if isinstance(value, int): result = ctx.new_var() # Need at least this many bits to store this constant value - result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # type: ignore + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2)) # Don't dictate anything about signedness - you can use a signed # constant in an unsigned variable if the bits fit result.add_constraint(TypeConstraintSigned(None)) - result.add_location(str(inp.value)) # type: ignore + result.add_location(str(value)) inp.type_var = result return result - raise NotImplementedError(constant, inp) + raise NotImplementedError(constant, inp, value) def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): @@ -43,58 +44,68 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': ctx.unify(left, right) return left + if isinstance(inp, ourlang.FunctionCall): + assert inp.function.returns_type_var is not None + if inp.function.posonlyargs: + raise NotImplementedError + + return inp.function.returns_type_var + raise NotImplementedError(expression, inp) def function(ctx: 'Context', inp: ourlang.Function) -> None: - for param in inp.posonlyargs: - param.type_var = _convert_old_type(ctx, param.type) - if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): raise NotImplementedError('Functions with not just a return statement') typ = expression(ctx, inp.statements[0].value) - ctx.unify(_convert_old_type(ctx, inp.returns), typ) + assert inp.returns_type_var is not None + ctx.unify(inp.returns_type_var, typ) return def module(inp: ourlang.Module) -> None: ctx = Context() + for func in inp.functions.values(): + func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)') + for param in func.posonlyargs: + param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') + for func in inp.functions.values(): function(ctx, func) from . import typing -def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: +def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar: result = ctx.new_var() if isinstance(inp, typing.TypeUInt8): result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location('u8') + result.add_location(location) return result if isinstance(inp, typing.TypeUInt32): result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location('u32') + result.add_location(location) return result if isinstance(inp, typing.TypeUInt64): - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location('u64') + result.add_location(location) return result if isinstance(inp, typing.TypeInt32): result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(True)) - result.add_location('i32') + result.add_location(location) return result if isinstance(inp, typing.TypeInt64): - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(True)) - result.add_location('i64') + result.add_location(location) return result raise NotImplementedError(_convert_old_type, inp) diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index f0c2993..449f34c 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -304,6 +304,21 @@ def testEntry(a: i32, b: i32) -> i32: assert 1 == suite.run_code(10, 20).returned_value assert 0 == suite.run_code(10, 10).returned_value +@pytest.mark.integration_test +def test_call_no_args(): + code_py = """ +def helper() -> i32: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + result = Suite(code_py).run_code() + + assert 19 == result.returned_value + @pytest.mark.integration_test def test_call_pre_defined(): code_py = """ diff --git a/tests/integration/test_type_checks.py b/tests/integration/test_type_checks.py new file mode 100644 index 0000000..afc1d1e --- /dev/null +++ b/tests/integration/test_type_checks.py @@ -0,0 +1,31 @@ +import pytest + +from phasm.parser import phasm_parse +from phasm.typer import phasm_type +from phasm.exceptions import TypingError + +@pytest.mark.integration_test +def test_constant_too_wide(): + code_py = """ +def func_const() -> u8: + return 0xFFF +""" + + ast = phasm_parse(code_py) + with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'): + phasm_type(ast) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', [32, 64]) +def test_signed_mismatch(type_): + code_py = f""" +def func_const() -> u{type_}: + return 0 + +def func_call() -> i{type_}: + return func_const() +""" + + ast = phasm_parse(code_py) + with pytest.raises(TypingError, match='Signed does not matchq'): + phasm_type(ast)