From 5da45e78c2d246f9606e6a53fa1ee4864b56504f Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 13:50:20 +0200 Subject: [PATCH] More work on StaticArray Also naming fix, buildin => builtin. Removes the use of ConstantStaticArray, as this was context dependent --- phasm/codestyle.py | 8 +- phasm/compiler.py | 6 +- phasm/ourlang.py | 16 +- phasm/parser.py | 61 ++--- phasm/typer.py | 49 +++- phasm/typing.py | 209 +++++++++++++----- pylintrc | 2 +- tests/integration/runners.py | 6 +- tests/integration/test_code/__init__.py | 0 tests/integration/test_code/test_typing.py | 20 ++ .../integration/test_lang/test_primitives.py | 2 +- .../test_lang/test_static_array.py | 41 +++- 12 files changed, 297 insertions(+), 123 deletions(-) create mode 100644 tests/integration/test_code/__init__.py create mode 100644 tests/integration/test_code/test_typing.py diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 500b878..d57c1db 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -54,7 +54,7 @@ def expression(inp: ourlang.Expression) -> str: # could not fit in the given float type return str(inp.value) - if isinstance(inp, (ourlang.ConstantTuple, ourlang.ConstantStaticArray, )): + if isinstance(inp, ourlang.ConstantTuple): return '(' + ', '.join( expression(x) for x in inp.value @@ -65,8 +65,8 @@ def expression(inp: ourlang.Expression) -> str: if isinstance(inp, ourlang.UnaryOp): if ( - inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS - or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): + inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS + or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS): return f'{inp.operator}({expression(inp.right)})' if inp.operator == 'cast': @@ -186,7 +186,7 @@ def module(inp: ourlang.Module) -> str: for func in inp.functions.values(): if func.lineno < 0: - # Buildin (-2) or auto generated (-1) + # Builtin (-2) or auto generated (-1) continue if result: diff --git a/phasm/compiler.py b/phasm/compiler.py index 826f21e..d17c197 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -247,11 +247,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: mtyp = typing.simplify(inp.type_var) if mtyp == 'f32': - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return if mtyp == 'f64': - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return @@ -608,7 +608,7 @@ def module_data(inp: ourlang.ModuleData) -> bytes: data_list.append(module_data_f64(constant.value)) continue - raise NotImplementedError(constant, mtyp) + raise NotImplementedError(constant, constant.type_var, mtyp) block_data = b''.join(data_list) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 733476f..b0e605d 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -7,8 +7,8 @@ import enum from typing_extensions import Final -WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) -WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) +WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) +WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', ) from .typing import ( TypeStruct, @@ -57,18 +57,6 @@ class ConstantTuple(Constant): super().__init__() self.value = value -class ConstantStaticArray(Constant): - """ - A StaticArray constant value expression within a statement - """ - __slots__ = ('value', ) - - value: List[ConstantPrimitive] - - def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? - super().__init__() - self.value = value - class VariableReference(Expression): """ An variable reference expression within a statement diff --git a/phasm/parser.py b/phasm/parser.py index f6570d1..51019be 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -6,24 +6,22 @@ from typing import Any, Dict, NoReturn, Union import ast from .typing import ( + BUILTIN_TYPES, + TypeStruct, TypeStructMember, - TypeTuple, - TypeTupleMember, - TypeStaticArray, - TypeStaticArrayMember, ) from .exceptions import StaticError from .ourlang import ( - WEBASSEMBLY_BUILDIN_FLOAT_OPS, + WEBASSEMBLY_BUILTIN_FLOAT_OPS, Module, ModuleDataBlock, Function, Expression, BinaryOp, - ConstantPrimitive, ConstantTuple, ConstantStaticArray, + ConstantPrimitive, ConstantTuple, FunctionCall, Subscript, # StructConstructor, TupleConstructor, @@ -482,7 +480,7 @@ class OurVisitor: # struct_constructor = StructConstructor(struct) # # func = module.functions[struct_constructor.name] - if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') @@ -686,7 +684,7 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.id in ('u8', 'u32', 'u64', 'i32', 'i64', 'f32', 'f64'): # FIXME: Source this list somewhere + if node.id in BUILTIN_TYPES: return node.id raise NotImplementedError('TODO: Broken after type system') @@ -697,40 +695,21 @@ class OurVisitor: _raise_static_error(node, f'Unrecognized type {node.id}') if isinstance(node, ast.Subscript): - raise NotImplementedError('TODO: Broken after new type system') + if not isinstance(node.value, ast.Name): + _raise_static_error(node, 'Must be name') + if not isinstance(node.slice, ast.Index): + _raise_static_error(node, 'Must subscript using an index') + if not isinstance(node.slice.value, ast.Constant): + _raise_static_error(node, 'Must subscript using a constant index') + if not isinstance(node.slice.value.value, int): + _raise_static_error(node, 'Must subscript using a constant integer index') + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') - # if not isinstance(node.value, ast.Name): - # _raise_static_error(node, 'Must be name') - # if not isinstance(node.slice, ast.Index): - # _raise_static_error(node, 'Must subscript using an index') - # if not isinstance(node.slice.value, ast.Constant): - # _raise_static_error(node, 'Must subscript using a constant index') - # if not isinstance(node.slice.value.value, int): - # _raise_static_error(node, 'Must subscript using a constant integer index') - # if not isinstance(node.ctx, ast.Load): - # _raise_static_error(node, 'Must be load context') - # - # if node.value.id in module.types: - # member_type = module.types[node.value.id] - # else: - # _raise_static_error(node, f'Unrecognized type {node.value.id}') - # - # type_static_array = TypeStaticArray(member_type) - # - # offset = 0 - # - # for idx in range(node.slice.value.value): - # static_array_member = TypeStaticArrayMember(idx, offset) - # - # type_static_array.members.append(static_array_member) - # offset += member_type.alloc_size() - # - # key = f'{node.value.id}[{node.slice.value.value}]' - # - # if key not in module.types: - # module.types[key] = type_static_array - # - # return module.types[key] + if node.value.id not in BUILTIN_TYPES: # FIXME: Tuple of tuples? + _raise_static_error(node, f'Unrecognized type {node.value.id}') + + return f'{node.value.id}[{node.slice.value.value}]' if isinstance(node, ast.Tuple): raise NotImplementedError('TODO: Broken after new type system') diff --git a/phasm/typer.py b/phasm/typer.py index ee356eb..d18aa1c 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -3,7 +3,13 @@ Type checks and enriches the given ast """ from . import ourlang -from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar, from_str +from .exceptions import TypingError +from .typing import ( + Context, + TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript, + TypeVar, + from_str, +) def phasm_type(inp: ourlang.Module) -> None: module(inp) @@ -55,6 +61,19 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar: raise NotImplementedError(constant, inp, inp.value) + if isinstance(inp, ourlang.ConstantTuple): + result = ctx.new_var() + + result.add_constraint(TypeConstraintSubscript(members=( + constant(ctx, x) + for x in inp.value + ))) + result.add_location(str(inp.value)) + + inp.type_var = result + + return result + raise NotImplementedError(constant, inp) def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': @@ -63,6 +82,8 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.VariableReference): assert inp.variable.type_var is not None + + inp.type_var = inp.variable.type_var return inp.variable.type_var if isinstance(inp, ourlang.UnaryOp): @@ -112,6 +133,32 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return inp.function.returns_type_var + if isinstance(inp, ourlang.Subscript): + if not isinstance(inp.index, ourlang.ConstantPrimitive): + raise NotImplementedError(expression, inp, inp.index) + if not isinstance(inp.index.value, int): + raise NotImplementedError(expression, inp, inp.index.value) + + expression(ctx, inp.varref) + assert inp.varref.type_var is not None + + try: + # TODO: I'd much rather resolve this using the narrow functions + tc_subs = ctx.var_constraints[inp.varref.type_var.ctx_id][TypeConstraintSubscript] + except KeyError: + raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None + + assert isinstance(tc_subs, TypeConstraintSubscript) # type hint + + try: + # TODO: I'd much rather resolve this using the narrow functions + member = tc_subs.members[inp.index.value] + except IndexError: + raise TypingError(f'Type cannot be subscripted with index {inp.index.value}: {inp.varref.type_var}') from None + + inp.type_var = member + return member + raise NotImplementedError(expression, inp) def function(ctx: Context, inp: ourlang.Function) -> None: diff --git a/phasm/typing.py b/phasm/typing.py index 0bebf8d..25f85a7 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,9 +1,10 @@ """ The phasm type system """ -from typing import Dict, Iterable, Optional, List, Set, Type +from typing import Callable, Dict, Iterable, Optional, List, Set, Type import enum +import re from .exceptions import TypingError @@ -151,7 +152,7 @@ class TypeConstraintBase: """ Base class for classes implementing a contraint on a type """ - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) class TypeConstraintPrimitive(TypeConstraintBase): @@ -172,7 +173,7 @@ class TypeConstraintPrimitive(TypeConstraintBase): def __init__(self, primitive: Primitive) -> None: self.primitive = primitive - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': if not isinstance(other, TypeConstraintPrimitive): raise Exception('Invalid comparison') @@ -196,7 +197,7 @@ class TypeConstraintSigned(TypeConstraintBase): def __init__(self, signed: Optional[bool]) -> None: self.signed = signed - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned': if not isinstance(other, TypeConstraintSigned): raise Exception('Invalid comparison') @@ -239,7 +240,7 @@ class TypeConstraintBitWidth(TypeConstraintBase): if x <= maxb } - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': if not isinstance(other, TypeConstraintBitWidth): raise Exception('Invalid comparison') @@ -251,7 +252,60 @@ class TypeConstraintBitWidth(TypeConstraintBase): return TypeConstraintBitWidth(oneof=new_oneof) def __repr__(self) -> str: - return 'BitWidth=oneof(' + ','.join(map(str, sorted(self.oneof))) + ')' + result = 'BitWidth=' + + items = list(sorted(self.oneof)) + if not items: + return result + + while items: + itm = items.pop(0) + result += str(itm) + + cnt = 0 + while cnt < len(items) and items[cnt] == itm + cnt + 1: + cnt += 1 + + if cnt == 1: + result += ',' + str(items[0]) + elif cnt > 1: + result += '..' + str(items[cnt - 1]) + + items = items[cnt:] + if items: + result += ',' + + return result + +class TypeConstraintSubscript(TypeConstraintBase): + """ + Contraint on allowing a type to be subscripted + """ + __slots__ = ('members', ) + + members: List['TypeVar'] + + def __init__(self, *, members: Iterable['TypeVar']) -> None: + self.members = list(members) + + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript': + if not isinstance(other, TypeConstraintSubscript): + raise Exception('Invalid comparison') + + if len(self.members) != len(other.members): + raise TypingNarrowProtoError('Member count does not match') + + newmembers = [] + for smb, omb in zip(self.members, other.members): + nmb = ctx.new_var() + ctx.unify(nmb, smb) + ctx.unify(nmb, omb) + newmembers.append(nmb) + + return TypeConstraintSubscript(members=newmembers) + + def __repr__(self) -> str: + return 'Subscript=(' + ','.join(map(repr, self.members)) + ')' class TypeVar: """ @@ -271,7 +325,7 @@ class TypeVar: csts = self.ctx.var_constraints[self.ctx_id] if newconst.__class__ in csts: - csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst) + csts[newconst.__class__] = csts[newconst.__class__].narrow(self.ctx, newconst) else: csts[newconst.__class__] = newconst @@ -413,6 +467,93 @@ def simplify(inp: TypeVar) -> Optional[str]: return None +def make_u8(ctx: Context, location: str) -> TypeVar: + """ + Makes a u8 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + +def make_u32(ctx: Context, location: str) -> TypeVar: + """ + Makes a u32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + +def make_u64(ctx: Context, location: str) -> TypeVar: + """ + Makes a u64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + +def make_i32(ctx: Context, location: str) -> TypeVar: + """ + Makes a i32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location(location) + return result + +def make_i64(ctx: Context, location: str) -> TypeVar: + """ + Makes a i64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location(location) + return result + +def make_f32(ctx: Context, location: str) -> TypeVar: + """ + Makes a f32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location(location) + return result + +def make_f64(ctx: Context, location: str) -> TypeVar: + """ + Makes a f64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location(location) + return result + +BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = { + 'u8': make_u8, + 'u32': make_u32, + 'u64': make_u64, + 'i32': make_i32, + 'i64': make_i64, + 'f32': make_f32, + 'f64': make_f64, +} + +TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]') + def from_str(ctx: Context, inp: str, location: str) -> TypeVar: """ Creates a new TypeVar from the string @@ -425,53 +566,21 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar: This could be conidered part of parsing. Though that would give trouble with the context creation. """ - result = ctx.new_var() + if inp in BUILTIN_TYPES: + return BUILTIN_TYPES[inp](ctx, location) - if inp == 'u8': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result + match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) + if match: + result = ctx.new_var() - if inp == 'u32': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_constraint(TypeConstraintSigned(False)) + result.add_constraint(TypeConstraintSubscript(members=( + # Make copies so they don't get entangled + # with each other. + from_str(ctx, match[1], match[1]) + for _ in range(int(match[2])) + ))) result.add_location(location) - return result - if inp == 'u64': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result - - if inp == 'i32': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) - return result - - if inp == 'i64': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) - return result - - if inp == 'f32': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_location(location) - return result - - if inp == 'f64': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_location(location) return result raise NotImplementedError(from_str, inp) diff --git a/pylintrc b/pylintrc index 82948bb..f872f8d 100644 --- a/pylintrc +++ b/pylintrc @@ -7,4 +7,4 @@ max-line-length=180 good-names=g [tests] -disable=C0116, +disable=C0116,R0201 diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 005d44e..77cd3f5 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO import ctypes import io -import warnings import pywasm.binary import wasm3 @@ -42,10 +41,7 @@ class RunnerBase: Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) - try: - phasm_type(self.phasm_ast) - except NotImplementedError as exc: - warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}') + phasm_type(self.phasm_ast) def compile_ast(self) -> None: """ diff --git a/tests/integration/test_code/__init__.py b/tests/integration/test_code/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_code/test_typing.py b/tests/integration/test_code/test_typing.py new file mode 100644 index 0000000..7be1140 --- /dev/null +++ b/tests/integration/test_code/test_typing.py @@ -0,0 +1,20 @@ +import pytest + +from phasm import typing as sut + +class TestTypeConstraintBitWidth: + @pytest.mark.parametrize('oneof,exp', [ + (set(), '', ), + ({1}, '1', ), + ({1,2}, '1,2', ), + ({1,2,3}, '1..3', ), + ({1,2,3,4}, '1..4', ), + + ({1,3}, '1,3', ), + ({1,4}, '1,4', ), + + ({1,2,3,4,6,7,8,9}, '1..4,6..9', ), + ]) + def test_repr(self, oneof, exp): + mut_self = sut.TypeConstraintBitWidth(oneof=oneof) + assert ('BitWidth=' + exp) == repr(mut_self) diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index 441dcdb..5736f0b 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -219,7 +219,7 @@ def testEntry() -> {type_}: @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['f32', 'f64']) -def test_buildins_sqrt(type_): +def test_builtins_sqrt(type_): code_py = f""" @exported def testEntry() -> {type_}: diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index 68bfdd4..6ea5985 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -1,12 +1,12 @@ import pytest -from phasm.exceptions import StaticError +from phasm.exceptions import StaticError, TypingError -from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +from ..constants import ALL_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP from ..helpers import Suite @pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +@pytest.mark.parametrize('type_', ALL_INT_TYPES) def test_module_constant(type_): code_py = f""" CONSTANT: {type_}[3] = (24, 57, 80, ) @@ -59,6 +59,41 @@ def helper(array: {type_}[3]) -> {type_}: assert 161 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +def test_module_constant_type_mismatch_bitwidth(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 280, ) +""" + + with pytest.raises(TypingError, match='u8.*280'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_not_subscriptable(): + code_py = """ +CONSTANT: u8 = 24 + +@exported +def testEntry() -> u8: + return CONSTANT[0] +""" + + with pytest.raises(TypingError, match='Type cannot be subscripted:'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_index_out_of_range(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +@exported +def testEntry() -> u8: + return CONSTANT[3] +""" + + with pytest.raises(TypingError, match='Type cannot be subscripted with index 3:'): + Suite(code_py).run_code() + @pytest.mark.integration_test def test_static_array_constant_too_few_values(): code_py = """