From 3cb48609733ba302bfc15d521ca5d5030461f5b5 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 2 Jun 2025 19:01:20 +0200 Subject: [PATCH] Subscriptable is now less hardcoded Now only the tuple variant is hardcoded. The rest is via a typeclass. --- TODO.md | 2 - phasm/build/default.py | 3 +- phasm/build/typeclasses/subscriptable.py | 147 ++++++++++++++++++ phasm/compiler.py | 101 ++---------- phasm/type3/constraints.py | 65 ++++---- phasm/type3/functions.py | 2 +- .../test_lang/test_subscriptable.py | 2 +- 7 files changed, 201 insertions(+), 121 deletions(-) create mode 100644 phasm/build/typeclasses/subscriptable.py diff --git a/TODO.md b/TODO.md index 2e8c01c..0757ab3 100644 --- a/TODO.md +++ b/TODO.md @@ -12,8 +12,6 @@ - Allocation is done using pointers for members, is this desired? - See if we want to replace Fractional with Real, and add Rational, Irrationl, Algebraic, Transendental - Implement q32? q64? Two i32/i64 divided? -- Does Subscript do what we want? It's a language feature rather a normal typed thing. How would you implement your own Subscript-able type? - - Clean up Subscript implementation - it's half implemented in the compiler. Makes more sense to move more parts to stdlib_types. - Have a set of rules or guidelines for the constraint comments, they're messy. 
- calculate_alloc_size can be reworked; is_member isn't useful with TYPE_INFO_MAP diff --git a/phasm/build/default.py b/phasm/build/default.py index f40536c..2cbcd8f 100644 --- a/phasm/build/default.py +++ b/phasm/build/default.py @@ -34,6 +34,7 @@ from .typeclasses import ( promotable, reinterpretable, sized, + subscriptable, ) @@ -88,7 +89,7 @@ class BuildDefault(BuildBase[Generator]): convertable, reinterpretable, natnum, intnum, fractional, floating, integral, - foldable, + foldable, subscriptable, sized, ] diff --git a/phasm/build/typeclasses/subscriptable.py b/phasm/build/typeclasses/subscriptable.py new file mode 100644 index 0000000..1caaa55 --- /dev/null +++ b/phasm/build/typeclasses/subscriptable.py @@ -0,0 +1,147 @@ +""" +The Subscriptable type class is defined for container types whose elements can be accessed by a u32 index using the `[]` operator. +""" +from typing import Any + +from ...type3.functions import TypeConstructorVariable, make_typevar +from ...type3.routers import TypeVariableLookup +from ...type3.typeclasses import Type3Class +from ...type3.types import IntType3, Type3 +from ...wasmgenerator import Generator as WasmGenerator +from ..base import BuildBase + + +def load(build: BuildBase[Any]) -> None: + a = make_typevar('a') + t = TypeConstructorVariable('t') + u32 = build.types['u32'] + + Subscriptable = Type3Class('Subscriptable', (t, ), methods={}, operators={ + '[]': [t(a), u32, a], + }) + + build.register_type_class(Subscriptable) + +class SubscriptableCodeGenerator: + def __init__(self, build: BuildBase[WasmGenerator]) -> None: + self.build = build + + def wasm_dynamic_array_getitem(self, g: WasmGenerator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + + assert isinstance(sa_type, Type3) + + ptr_type_info = self.build.type_info_map['ptr'] + u32_type_info = self.build.type_info_map['u32'] + + sa_type_info = self.build.type_info_map.get(sa_type.name) + if sa_type_info is None: + 
sa_type_info = ptr_type_info + + getitem_adr = g.temp_var_t(u32_type_info.wasm_type, 'getitem_adr') + getitem_idx = g.temp_var_t(u32_type_info.wasm_type, 'getitem_idx') + + # Stack: [varref: *ard, idx: u32] + g.local.set(getitem_idx) + # Stack: [varref: *ard] + g.local.set(getitem_adr) + # Stack: [] + + # Out of bounds check based on memory stored length + # Stack: [] + g.local.get(getitem_idx) + # Stack: [idx: u32] + g.local.get(getitem_adr) + # Stack: [idx: u32, varref: *ard] + g.i32.load() + # Stack: [idx: u32, len: u32] + g.i32.ge_u() + # Stack: [res: bool] + with g.if_(): + g.unreachable(comment='Out of bounds') + + # Stack: [] + g.local.get(getitem_adr) + # Stack: [varref: *ard] + g.i32.const(4) + # Stack: [varref: *ard, 4] + g.i32.add() + # Stack: [firstel: *ard] + g.local.get(getitem_idx) + # Stack: [firstel: *ard, idx: u32] + g.i32.const(sa_type_info.alloc_size) + # Stack: [firstel: *ard, idx: u32, as: u32] + g.i32.mul() + # Stack: [firstel: *ard, offset: u32] + g.i32.add() + # Stack: [eladr: *ard] + g.add_statement(sa_type_info.wasm_load_func) + # Stack: [el] + + def wasm_static_array_getitem(self, g: WasmGenerator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + + assert isinstance(sa_type, Type3) + assert isinstance(sa_len, IntType3) + + ptr_type_info = self.build.type_info_map['ptr'] + u32_type_info = self.build.type_info_map['u32'] + + sa_type_info = self.build.type_info_map.get(sa_type.name) + if sa_type_info is None: + sa_type_info = ptr_type_info + + # OPTIMIZE: If index is a constant, we can use offset instead of multiply + # and we don't need to do the out of bounds check + getitem_idx = g.temp_var_t(u32_type_info.wasm_type, 'getitem_idx') + + # Stack: [varref: *ard, idx: u32] + g.local.tee(getitem_idx) + + # Stack: [varref: *ard, idx: u32] + # Out of bounds check based on sa_len.value + g.i32.const(sa_len.value) + # Stack: 
[varref: *ard, idx: u32, len: u32] + g.i32.ge_u() + # Stack: [varref: *ard, res: bool] + with g.if_(): + g.unreachable(comment='Out of bounds') + + # Stack: [varref: *ard] + g.local.get(getitem_idx) + # Stack: [varref: *ard, idx: u32] + g.i32.const(sa_type_info.alloc_size) + # Stack: [varref: *ard, idx: u32, as: u32] + g.i32.mul() + # Stack: [varref: *ard, offset: u32] + g.i32.add() + # Stack: [eladr: *ard] + g.add_statement(sa_type_info.wasm_load_func) + # Stack: [el] + +def wasm(build: BuildBase[WasmGenerator]) -> None: + Subscriptable = build.type_classes['Subscriptable'] + + gen = SubscriptableCodeGenerator(build) + + build.instance_type_class(Subscriptable, build.dynamic_array, operators={ + '[]': gen.wasm_dynamic_array_getitem, + }) + build.instance_type_class(Subscriptable, build.static_array, operators={ + '[]': gen.wasm_static_array_getitem, + }) diff --git a/phasm/compiler.py b/phasm/compiler.py index 327fedb..4851e95 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -133,87 +133,13 @@ def tuple_instantiation(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], # Return the allocated address wgn.local.get(tmp_var) -def expression_subscript_dynamic_array( - attrs: tuple[WasmGenerator, ourlang.Module[WasmGenerator], ourlang.Subscript], - args: tuple[Type3], - ) -> None: - wgn, mod, inp = attrs - el_type, = args - - el_type_info = mod.build.type_info_map.get(el_type.name) - if el_type_info is None: - el_type_info = mod.build.type_info_map['ptr'] - - tmp_idx = wgn.temp_var_i32('tmp_idx') - tmp_adr = wgn.temp_var_i32('tmp_adr') - - expression(wgn, mod, inp.varref) - expression(wgn, mod, inp.index) - - wgn.local.set(tmp_idx) - wgn.local.set(tmp_adr) - - # Out of bounds check based on size stored in memory - wgn.local.get(tmp_idx) - wgn.local.get(tmp_adr) - wgn.i32.load() - wgn.i32.ge_u() - with wgn.if_(): - wgn.unreachable(comment='Out of bounds') - - # tmp_ard + 4 + (tmp_idx * alloc_size) - wgn.local.get(tmp_adr) - wgn.i32.const(4) - wgn.i32.add() - 
wgn.local.get(tmp_idx) - wgn.i32.const(el_type_info.alloc_size) - wgn.i32.mul() - wgn.i32.add() - - wgn.add_statement(el_type_info.wasm_load_func) - -def expression_subscript_static_array( - attrs: tuple[WasmGenerator, ourlang.Module[WasmGenerator], ourlang.Subscript], - args: tuple[Type3, IntType3], - ) -> None: - wgn, mod, inp = attrs - - el_type, el_len = args - - # OPTIMIZE: If index is a constant, we can use offset instead of multiply - # and we don't need to do the out of bounds check - tmp_var = wgn.temp_var_i32('index') - - expression(wgn, mod, inp.varref) - expression(wgn, mod, inp.index) - - wgn.local.tee(tmp_var) - - # Out of bounds check based on el_len.value - wgn.i32.const(el_len.value) - wgn.i32.ge_u() - with wgn.if_(): - wgn.unreachable(comment='Out of bounds') - - el_type_info = mod.build.type_info_map.get(el_type.name) - if el_type_info is None: - el_type_info = mod.build.type_info_map['ptr'] - - wgn.local.get(tmp_var) - wgn.i32.const(el_type_info.alloc_size) - wgn.i32.mul() - wgn.i32.add() - - wgn.add_statement(el_type_info.wasm_load_func) - -def expression_subscript_tuple( - attrs: tuple[WasmGenerator, ourlang.Module[WasmGenerator], ourlang.Subscript], - args: tuple[Type3, ...], - ) -> None: - wgn, mod, inp = attrs - +def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Subscript) -> None: assert isinstance(inp.index, ourlang.ConstantPrimitive) assert isinstance(inp.index.value, int) + assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR + assert isinstance(inp.varref.type3.application, TypeApplication_TypeStar) + + args = inp.varref.type3.application.arguments offset = 0 for el_type in args[0:inp.index.value]: @@ -233,11 +159,6 @@ def expression_subscript_tuple( el_type_info = mod.build.type_info_map['ptr'] wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}') -SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Module[WasmGenerator], ourlang.Subscript], None]() 
-SUBSCRIPT_ROUTER.add(builtins.dynamic_array, expression_subscript_dynamic_array) -SUBSCRIPT_ROUTER.add(builtins.static_array, expression_subscript_static_array) -SUBSCRIPT_ROUTER.add(builtins.tuple_, expression_subscript_tuple) - def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None: """ Compile: Any expression @@ -388,10 +309,18 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl return if isinstance(inp, ourlang.Subscript): + assert inp.type3 is not None, TYPE3_ASSERTION_ERROR assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR - # Type checker guarantees we don't get routing errors - SUBSCRIPT_ROUTER((wgn, mod, inp, ), inp.varref.type3) + if inp.varref.type3.application.constructor is mod.build.tuple_: + expression_subscript_tuple(wgn, mod, inp) + return + + inp_as_fc = ourlang.FunctionCall(mod.build.type_classes['Subscriptable'].operators['[]']) + inp_as_fc.type3 = inp.type3 + inp_as_fc.arguments = [inp.varref, inp.index] + + expression(wgn, mod, inp_as_fc) return if isinstance(inp, ourlang.AccessStructMember): diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index e2304a0..b8c71c4 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -634,34 +634,9 @@ class CanBeSubscriptedConstraint(ConstraintBase): self.index_type3 = index_type3 self.index_const = index_const - - self.generate_router = TypeApplicationRouter() - self.generate_router.add(context.build.dynamic_array, self.__class__._generate_dynamic_array) - self.generate_router.add(context.build.static_array, self.__class__._generate_static_array) - self.generate_router.add(context.build.tuple_, self.__class__._generate_tuple) - - def _generate_dynamic_array(self, da_args: tuple[Type3]) -> CheckResult: - da_type, = da_args - - return [ - SameTypeConstraint(self.context, self.context.build.types['u32'], self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), - 
SameTypeConstraint(self.context, da_type, self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), - ] - - def _generate_static_array(self, sa_args: tuple[Type3, IntType3]) -> CheckResult: - sa_type, sa_len = sa_args - - if self.index_const is not None and (self.index_const < 0 or sa_len.value <= self.index_const): - return Error('Tuple index out of range') - - return [ - SameTypeConstraint(self.context, self.context.build.types['u32'], self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), - SameTypeConstraint(self.context, sa_type, self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), - ] - def _generate_tuple(self, tp_args: tuple[Type3, ...]) -> CheckResult: # We special case tuples to allow for ease of use to the programmer - # e.g. rather than having to do `fst a` and `snd a` and only have to-sized tuples + # e.g. rather than having to do `fst a` and `snd a` and only have tuples of size 2 # we use a[0] and a[1] and allow for a[2] and on. 
if self.index_const is None: @@ -681,10 +656,40 @@ class CanBeSubscriptedConstraint(ConstraintBase): exp_type = self.type3.resolve_as - try: - return self.generate_router(self, exp_type) - except NoRouteForTypeException: - return Error(f'{exp_type.name} cannot be subscripted') + if exp_type.application.constructor == self.context.build.tuple_: + return self._generate_tuple(exp_type.application.arguments) + + result: NewConstraintList = [] + result.extend([ + MustImplementTypeClassConstraint( + self.context, + self.context.build.type_classes['Subscriptable'], + [exp_type], + ), + SameTypeConstraint( + self.context, + self.context.build.types['u32'], + self.index_type3, + ), + ]) + + if isinstance(exp_type.application, (TypeApplication_Type, TypeApplication_TypeInt, )): + result.extend([ + SameTypeConstraint( + self.context, + exp_type.application.arguments[0], + self.ret_type3, + ), + ]) + # else: The MustImplementTypeClassConstraint will catch this + + if exp_type.application.constructor == self.context.build.static_array: + _, sa_len = exp_type.application.arguments + + if self.index_const is not None and (self.index_const < 0 or sa_len.value <= self.index_const): + return Error('Tuple index out of range') + + return result def human_readable(self) -> HumanReadableRet: return ( diff --git a/phasm/type3/functions.py b/phasm/type3/functions.py index bbd8df9..d10e2b1 100644 --- a/phasm/type3/functions.py +++ b/phasm/type3/functions.py @@ -100,7 +100,7 @@ class TypeConstructorVariable: return False if not isinstance(other, TypeConstructorVariable): - raise NotImplementedError + raise NotImplementedError(other) return (self.name == other.name) diff --git a/tests/integration/test_lang/test_subscriptable.py b/tests/integration/test_lang/test_subscriptable.py index 7cdce03..d59ecc2 100644 --- a/tests/integration/test_lang/test_subscriptable.py +++ b/tests/integration/test_lang/test_subscriptable.py @@ -147,5 +147,5 @@ def testEntry(x: u8) -> u8: return x[0] """ - with 
pytest.raises(Type3Exception, match='u8 cannot be subscripted'): + with pytest.raises(Type3Exception, match='Missing type class instantation: Subscriptable u8'): Suite(code_py).run_code()