Type compare, code cleanup, extra test

This commit is contained in:
Johan B.W. de Vries 2022-12-24 19:22:03 +01:00
parent 6e0c554cf2
commit e456f55bb0
6 changed files with 116 additions and 239 deletions

View File

@ -16,18 +16,6 @@ def phasm_render(inp: ourlang.Module) -> str:
return module(inp) return module(inp)
Statements = Generator[str, None, None] Statements = Generator[str, None, None]
#
# def type_var(inp: Optional[typing.TypeVar]) -> str:
# """
# Render: type's name
# """
# assert inp is not None, typing.ASSERTION_ERROR
#
# mtyp = typing.simplify(inp)
# if mtyp is None:
# raise NotImplementedError(f'Rendering type {inp}')
#
# return mtyp
def type3(inp: Type3OrPlaceholder) -> str: def type3(inp: Type3OrPlaceholder) -> str:
""" """
@ -36,13 +24,13 @@ def type3(inp: Type3OrPlaceholder) -> str:
assert isinstance(inp, Type3), TYPE3_ASSERTION_ERROR assert isinstance(inp, Type3), TYPE3_ASSERTION_ERROR
if isinstance(inp, type3types.AppliedType3): if isinstance(inp, type3types.AppliedType3):
if inp.base is type3types.tuple: if inp.base == type3types.tuple:
return '(' + ', '.join( return '(' + ', '.join(
type3(x) type3(x)
for x in inp.args for x in inp.args
) + ', )' ) + ', )'
if inp.base is type3types.static_array: if inp.base == type3types.static_array:
assert 1 == len(inp.args) assert 1 == len(inp.args)
assert isinstance(inp.args[0], Type3), TYPE3_ASSERTION_ERROR assert isinstance(inp.args[0], Type3), TYPE3_ASSERTION_ERROR

View File

@ -40,37 +40,41 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType:
""" """
assert isinstance(inp, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if inp is type3types.u8: if inp == type3types.u8:
# WebAssembly has only support for 32 and 64 bits # WebAssembly has only support for 32 and 64 bits
# So we need to store more memory per byte # So we need to store more memory per byte
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if inp is type3types.u32: if inp == type3types.u32:
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if inp is type3types.u64: if inp == type3types.u64:
return wasm.WasmTypeInt64() return wasm.WasmTypeInt64()
if inp is type3types.i32: if inp == type3types.i32:
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if inp is type3types.i64: if inp == type3types.i64:
return wasm.WasmTypeInt64() return wasm.WasmTypeInt64()
if inp is type3types.f32: if inp == type3types.f32:
return wasm.WasmTypeFloat32() return wasm.WasmTypeFloat32()
if inp is type3types.f64: if inp == type3types.f64:
return wasm.WasmTypeFloat64() return wasm.WasmTypeFloat64()
if inp is type3types.bytes: if inp == type3types.bytes:
# bytes are passed as pointer # bytes are passed as pointer
# And pointers are i32 # And pointers are i32
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, type3types.StructType3): if isinstance(inp, type3types.StructType3):
# Structs and tuples are passed as pointer # Structs are passed as pointer, which are i32
# And pointers are i32 return wasm.WasmTypeInt32()
if isinstance(inp, type3types.AppliedType3):
if inp.base == type3types.static_array:
# Static Arrays are passed as pointer, which are i32
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
raise NotImplementedError(type3, inp) raise NotImplementedError(type3, inp)
@ -150,28 +154,28 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp, ourlang.ConstantPrimitive): if isinstance(inp, ourlang.ConstantPrimitive):
assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if inp.type3 is type3types.u8: if inp.type3 == type3types.u8:
# No native u8 type - treat as i32, with caution # No native u8 type - treat as i32, with caution
assert isinstance(inp.value, int) assert isinstance(inp.value, int)
wgn.i32.const(inp.value) wgn.i32.const(inp.value)
return return
if inp.type3 is type3types.i32 or inp.type3 is type3types.u32: if inp.type3 in (type3types.i32, type3types.u32, ):
assert isinstance(inp.value, int) assert isinstance(inp.value, int)
wgn.i32.const(inp.value) wgn.i32.const(inp.value)
return return
if inp.type3 is type3types.i64 or inp.type3 is type3types.u64: if inp.type3 in (type3types.i64, type3types.u64, ):
assert isinstance(inp.value, int) assert isinstance(inp.value, int)
wgn.i64.const(inp.value) wgn.i64.const(inp.value)
return return
if inp.type3 is type3types.f32: if inp.type3 == type3types.f32:
assert isinstance(inp.value, float) assert isinstance(inp.value, float)
wgn.f32.const(inp.value) wgn.f32.const(inp.value)
return return
if inp.type3 is type3types.f64: if inp.type3 == type3types.f64:
assert isinstance(inp.value, float) assert isinstance(inp.value, float)
wgn.f64.const(inp.value) wgn.f64.const(inp.value)
return return
@ -193,7 +197,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
return return
if isinstance(inp.type3, type3types.AppliedType3): if isinstance(inp.type3, type3types.AppliedType3):
if inp.type3.base is type3types.static_array: if inp.type3.base == type3types.static_array:
assert inp.variable.data_block is not None, 'Static arrays must be memory stored' assert inp.variable.data_block is not None, 'Static arrays must be memory stored'
assert inp.variable.data_block.address is not None, 'Value not allocated' assert inp.variable.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.variable.data_block.address) wgn.i32.const(inp.variable.data_block.address)
@ -208,12 +212,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# return # return
# #
# if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# assert inp.variable.data_block is not None, 'Combined values are memory stored'
# assert inp.variable.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.variable.data_block.address)
# return
assert inp.variable.data_block is None, 'Primitives are not memory stored' assert inp.variable.data_block is None, 'Primitives are not memory stored'
expression(wgn, inp.variable.constant) expression(wgn, inp.variable.constant)
@ -228,46 +226,46 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
# FIXME: Re-implement build-in operators # FIXME: Re-implement build-in operators
if inp.type3 is type3types.u8: if inp.type3 == type3types.u8:
if operator := U8_OPERATOR_MAP.get(inp.operator, None): if operator := U8_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if inp.type3 is type3types.u32: if inp.type3 == type3types.u32:
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if operator := U32_OPERATOR_MAP.get(inp.operator, None): if operator := U32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if inp.type3 is type3types.u64: if inp.type3 == type3types.u64:
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if operator := U64_OPERATOR_MAP.get(inp.operator, None): if operator := U64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if inp.type3 is type3types.i32: if inp.type3 == type3types.i32:
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if operator := I32_OPERATOR_MAP.get(inp.operator, None): if operator := I32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if inp.type3 is type3types.i64: if inp.type3 == type3types.i64:
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if operator := I64_OPERATOR_MAP.get(inp.operator, None): if operator := I64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if inp.type3 is type3types.f32: if inp.type3 == type3types.f32:
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}') wgn.add_statement(f'f32.{operator}')
return return
if operator := F32_OPERATOR_MAP.get(inp.operator, None): if operator := F32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}') wgn.add_statement(f'f32.{operator}')
return return
if inp.type3 is type3types.f64: if inp.type3 == type3types.f64:
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f64.{operator}') wgn.add_statement(f'f64.{operator}')
return return
@ -282,18 +280,18 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if inp.type3 is type3types.f32: if inp.type3 == type3types.f32:
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f32.{inp.operator}') wgn.add_statement(f'f32.{inp.operator}')
return return
if inp.type3 is type3types.f64: if inp.type3 == type3types.f64:
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f64.{inp.operator}') wgn.add_statement(f'f64.{inp.operator}')
return return
if inp.type3 is type3types.u32: if inp.type3 == type3types.u32:
if inp.operator == 'len': if inp.operator == 'len':
if inp.right.type3 is type3types.bytes: if inp.right.type3 == type3types.bytes:
wgn.i32.load() wgn.i32.load()
return return
@ -316,74 +314,32 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
assert isinstance(inp.varref.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.varref.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if isinstance(inp.varref.type3, type3types.AppliedType3): if isinstance(inp.varref.type3, type3types.AppliedType3):
if inp.varref.type3.base is type3types.static_array: if inp.varref.type3.base == type3types.static_array:
assert 1 == len(inp.varref.type3.args) assert 1 == len(inp.varref.type3.args)
assert isinstance(inp.varref.type3.args[0], type3types.Type3) el_type = inp.varref.type3.args[0]
assert isinstance(el_type, type3types.Type3)
# OPTIMIZE: If index is a constant, we can use offset instead of multiply
# FIXME: Out of bounds check
expression(wgn, inp.varref) expression(wgn, inp.varref)
expression(wgn, inp.index) expression(wgn, inp.index)
wgn.i32.const(_calculate_alloc_size(inp.varref.type3.args[0])) wgn.i32.const(_calculate_alloc_size(el_type))
wgn.i32.mul() wgn.i32.mul()
wgn.i32.add() wgn.i32.add()
mtyp = LOAD_STORE_TYPE_MAP.get(el_type.name)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, el_type)
wgn.add_statement(f'{mtyp}.load')
return return
#
# assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
# tc_type = inp.varref.type_var.get_type()
# if tc_type is None:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
# if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# if not isinstance(inp.index, ourlang.ConstantPrimitive):
# raise NotImplementedError(expression, inp, inp.index)
# if not isinstance(inp.index.value, int):
# raise NotImplementedError(expression, inp, inp.index.value)
#
# assert inp.type_var is not None, typing.ASSERTION_ERROR
# mtyp = typing.simplify(inp.type_var)
# if mtyp is None:
# raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
#
# if mtyp == 'u8':
# # u8 operations are done using i32, since WASM does not have u8 operations
# mtyp = 'i32'
# elif mtyp == 'u32':
# # u32 operations are done using i32, using _u operations
# mtyp = 'i32'
# elif mtyp == 'u64':
# # u64 operations are done using i64, using _u operations
# mtyp = 'i64'
#
# tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
# if tc_subs is None:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
#
# assert 0 < len(tc_subs.members)
# tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
# if tc_bits is None or len(tc_bits.oneof) > 1:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
#
# bitwidth = next(iter(tc_bits.oneof))
# if bitwidth % 8 != 0:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
#
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
# return
raise NotImplementedError(expression, inp, inp.varref.type3) raise NotImplementedError(expression, inp, inp.varref.type3)
# TODO: Broken after new type system
# if isinstance(inp, ourlang.AccessBytesIndex):
# if not isinstance(inp.type, typing.TypeUInt8):
# raise NotImplementedError(inp, inp.type)
#
# expression(wgn, inp.varref)
# expression(wgn, inp.index)
# wgn.call(stdlib_types.__subscript_bytes__)
# return
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.struct_type3.members[inp.member].name) mtyp = LOAD_STORE_TYPE_MAP.get(inp.struct_type3.members[inp.member].name)
if mtyp is None: if mtyp is None:
@ -397,32 +353,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
))) )))
return return
# if isinstance(inp, ourlang.AccessTupleMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of struct or tuples
# raise NotImplementedError(expression, inp, inp.member)
#
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# if isinstance(inp, ourlang.AccessStaticArrayMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of static arrays
# raise NotImplementedError(expression, inp, inp.member)
#
# expression(wgn, inp.varref)
# expression(wgn, inp.member)
# wgn.i32.const(inp.static_array.member_type.alloc_size())
# wgn.i32.mul()
# wgn.i32.add()
# wgn.add_statement(f'{mtyp}.load')
# return
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp) expression_fold(wgn, inp)
return return
@ -662,37 +592,37 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
for constant in block.data: for constant in block.data:
assert isinstance(constant.type3, type3types.Type3), (id(constant), type3types.TYPE3_ASSERTION_ERROR) assert isinstance(constant.type3, type3types.Type3), (id(constant), type3types.TYPE3_ASSERTION_ERROR)
if constant.type3 is type3types.u8: if constant.type3 == type3types.u8:
assert isinstance(constant.value, int) assert isinstance(constant.value, int)
data_list.append(module_data_u8(constant.value)) data_list.append(module_data_u8(constant.value))
continue continue
if constant.type3 is type3types.u32: if constant.type3 == type3types.u32:
assert isinstance(constant.value, int) assert isinstance(constant.value, int)
data_list.append(module_data_u32(constant.value)) data_list.append(module_data_u32(constant.value))
continue continue
if constant.type3 is type3types.u64: if constant.type3 == type3types.u64:
assert isinstance(constant.value, int) assert isinstance(constant.value, int)
data_list.append(module_data_u64(constant.value)) data_list.append(module_data_u64(constant.value))
continue continue
if constant.type3 is type3types.i32: if constant.type3 == type3types.i32:
assert isinstance(constant.value, int) assert isinstance(constant.value, int)
data_list.append(module_data_i32(constant.value)) data_list.append(module_data_i32(constant.value))
continue continue
if constant.type3 is type3types.i64: if constant.type3 == type3types.i64:
assert isinstance(constant.value, int) assert isinstance(constant.value, int)
data_list.append(module_data_i64(constant.value)) data_list.append(module_data_i64(constant.value))
continue continue
if constant.type3 is type3types.f32: if constant.type3 == type3types.f32:
assert isinstance(constant.value, float) assert isinstance(constant.value, float)
data_list.append(module_data_f32(constant.value)) data_list.append(module_data_f32(constant.value))
continue continue
if constant.type3 is type3types.f64: if constant.type3 == type3types.f64:
assert isinstance(constant.value, float) assert isinstance(constant.value, float)
data_list.append(module_data_f64(constant.value)) data_list.append(module_data_f64(constant.value))
continue continue
@ -797,13 +727,13 @@ def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstruc
wgn.local.get(tmp_var) wgn.local.get(tmp_var)
def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) -> int: def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) -> int:
if typ is type3types.u8: if typ == type3types.u8:
return 1 return 4 # FIXME: We allocate 4 bytes for every u8 since you load them into an i32
if typ is type3types.u32 or typ is type3types.i32 or typ is type3types.f32: if typ in (type3types.u32, type3types.i32, type3types.f32, ):
return 4 return 4
if typ is type3types.u64 or typ is type3types.i64 or typ is type3types.f64: if typ in (type3types.u64, type3types.i64, type3types.f64, ):
return 8 return 8
if isinstance(typ, type3types.StructType3): if isinstance(typ, type3types.StructType3):

View File

@ -633,63 +633,6 @@ class OurVisitor:
return Subscript(varref, slice_expr) return Subscript(varref, slice_expr)
# if isinstance(node_typ, TypeBytes):
# if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')
#
# return AccessBytesIndex(
# varref,
# slice_expr,
# )
#
# if isinstance(node_typ, TypeTuple):
# if not isinstance(slice_expr, ConstantPrimitive):
# _raise_static_error(node, 'Must subscript using a constant index')
#
# idx = slice_expr.value
#
# if not isinstance(idx, int):
# _raise_static_error(node, 'Must subscript using a constant integer index')
#
# if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
#
# tuple_member = node_typ.members[idx]
#
# if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')
#
# return AccessTupleMember(
# varref,
# tuple_member,
# )
#
# if isinstance(node_typ, TypeStaticArray):
# if not isinstance(slice_expr, ConstantPrimitive):
# return AccessStaticArrayMember(
# varref,
# node_typ,
# slice_expr,
# )
#
# idx = slice_expr.value
#
# if not isinstance(idx, int):
# _raise_static_error(node, 'Must subscript using an integer index')
#
# if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
#
# static_array_member = node_typ.members[idx]
#
# return AccessStaticArrayMember(
# varref,
# node_typ,
# static_array_member,
# )
#
# _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
del module del module

View File

@ -122,7 +122,7 @@ class SameTypeConstraint(ConstraintBase):
first_type = known_types[0] first_type = known_types[0]
for typ in known_types[1:]: for typ in known_types[1:]:
if typ is not first_type: if typ != first_type:
return Error(f'{typ:s} must be {first_type:s} instead') return Error(f'{typ:s} must be {first_type:s} instead')
if not placeholders: if not placeholders:
@ -265,7 +265,7 @@ class LiteralFitsConstraint(ConstraintBase):
res: NewConstraintList res: NewConstraintList
if isinstance(self.type3, types.AppliedType3): if isinstance(self.type3, types.AppliedType3):
if self.type3.base is types.tuple: if self.type3.base == types.tuple:
if not isinstance(self.literal, ourlang.ConstantTuple): if not isinstance(self.literal, ourlang.ConstantTuple):
return Error('Must be tuple') return Error('Must be tuple')
@ -285,7 +285,7 @@ class LiteralFitsConstraint(ConstraintBase):
return res return res
if self.type3.base is types.static_array: if self.type3.base == types.static_array:
if not isinstance(self.literal, ourlang.ConstantTuple): if not isinstance(self.literal, ourlang.ConstantTuple):
return Error('Must be tuple') return Error('Must be tuple')
@ -371,42 +371,17 @@ class CanBeSubscriptedConstraint(ConstraintBase):
self.type3 = smap[self.type3] self.type3 = smap[self.type3]
if isinstance(self.type3, types.AppliedType3): if isinstance(self.type3, types.AppliedType3):
if self.type3.base is types.static_array: if self.type3.base == types.static_array:
return [ return [
SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b') SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b')
] ]
raise NotImplementedError # FIXME: bytes
# if isinstance(self.type3, types.PlaceholderForType):
# return RequireTypeSubstitutes() if self.type3.name in types.LOOKUP_TABLE:
# return Error(f'{self.type3.name} cannot be subscripted')
# if isinstance(self.index_type3, types.PlaceholderForType):
# return RequireTypeSubstitutes() raise NotImplementedError(self.type3)
#
# if not isinstance(self.type3, types.AppliedType3):
# return Error(f'Cannot subscript {self.type3:s}')
#
# if self.type3.base is types.tuple:
# return None
#
# raise NotImplementedError(self.type3)
#
# def get_new_placeholder_substitutes(self) -> SubstitutionMap:
# if isinstance(self.type3, types.AppliedType3) and self.type3.base is types.tuple and isinstance(self.index_type3, types.PlaceholderForType):
# return {
# self.index_type3: types.u32,
# }
#
# return {}
#
# def substitute_placeholders(self, smap: SubstitutionMap) -> None: # FIXME: Duplicate code
# if isinstance(self.type3, types.PlaceholderForType) and self.type3 in smap: # FIXME: Check recursive?
# self.type3.get_substituted(smap[self.type3])
# self.type3 = smap[self.type3]
#
# if isinstance(self.index_type3, types.PlaceholderForType) and self.index_type3 in smap: # FIXME: Check recursive?
# self.index_type3.get_substituted(smap[self.index_type3])
# self.index_type3 = smap[self.index_type3]
def human_readable(self) -> HumanReadableRet: def human_readable(self) -> HumanReadableRet:
return ( return (

View File

@ -45,10 +45,16 @@ class Type3:
return str(self) return str(self)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if isinstance(other, PlaceholderForType):
return False
if not isinstance(other, Type3):
raise NotImplementedError raise NotImplementedError
return self is other
def __ne__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool:
raise NotImplementedError return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
raise NotImplementedError raise NotImplementedError
@ -82,13 +88,16 @@ class PlaceholderForType:
return str(self) return str(self)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if isinstance(other, Type3):
return False
if not isinstance(other, PlaceholderForType): if not isinstance(other, PlaceholderForType):
raise NotImplementedError raise NotImplementedError
return self is other return self is other
def __ne__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool:
raise NotImplementedError return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
return 0 # Valid but performs badly return 0 # Valid but performs badly
@ -127,6 +136,22 @@ class AppliedType3(Type3):
self.base = base self.base = base
self.args = args self.args = args
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Type3):
raise NotImplementedError
if not isinstance(other, AppliedType3):
return False
return (
self.base == other.base
and len(self.args) == len(other.args)
and all(
s == x
for s, x in zip(self.args, other.args)
)
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'AppliedType3({repr(self.base)}, {repr(self.args)})' return f'AppliedType3({repr(self.base)}, {repr(self.args)})'

View File

@ -34,7 +34,7 @@ def testEntry() -> {type_}:
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
assert 24 == result.returned_value assert 57 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@ -95,6 +95,22 @@ def helper(array: {type_}[3]) -> {type_}:
assert 162.25 == result.returned_value assert 162.25 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_function_call_element():
code_py = """
CONSTANT: u64[3] = (250, 250000, 250000000, )
@exported
def testEntry() -> u8:
return helper(CONSTANT[0])
def helper(x: u8) -> u8:
return x
"""
with pytest.raises(Type3Exception, match=r'u8 must be u64 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_module_constant_type_mismatch_bitwidth(): def test_module_constant_type_mismatch_bitwidth():
code_py = """ code_py = """
@ -126,7 +142,7 @@ def testEntry() -> u8:
return CONSTANT[0] return CONSTANT[0]
""" """
with pytest.raises(Type3Exception, match='Type cannot be subscripted:'): with pytest.raises(Type3Exception, match='u8 cannot be subscripted'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test