Trying out some things regarding StaticArray

This commit is contained in:
Johan B.W. de Vries 2022-11-16 13:50:03 +01:00
parent 42c9ff6ca7
commit 55a45ff17c
4 changed files with 72 additions and 26 deletions

View File

@ -175,6 +175,7 @@ def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None:
inp.type_var = from_str(ctx, inp.type_str) inp.type_var = from_str(ctx, inp.type_str)
assert inp.constant.type_var is not None assert inp.constant.type_var is not None
# This doesn't work sufficiently with StaticArray
ctx.unify(inp.type_var, inp.constant.type_var) ctx.unify(inp.type_var, inp.constant.type_var)
def module(inp: ourlang.Module) -> None: def module(inp: ourlang.Module) -> None:

View File

@ -1,7 +1,7 @@
""" """
The phasm type system The phasm type system
""" """
from typing import Callable, Dict, Iterable, Optional, List, Set, Type from typing import Any, Callable, Dict, Iterable, Optional, List, Set, Type, Union
from typing import TypeVar as MyPyTypeVar from typing import TypeVar as MyPyTypeVar
import enum import enum
@ -132,17 +132,55 @@ class TypeStruct(TypeBase):
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class PhasmType: class PhasmType:
__slots__ = ('name', ) __slots__ = ('name', 'args', 'arg_count', )
name: str
def __init__(self, name: str) -> None: name: str
args: List[Union['PhasmType', 'TypeVar']]
arg_count: int
def __init__(self, name: str, arg_count: int = 0) -> None:
self.name = name self.name = name
self.args = []
self.arg_count = arg_count
def __call__(self, type_arg: Union['PhasmType', 'TypeVar']) -> 'PhasmType':
assert 0 < self.arg_count
result = PhasmType(self.name, self.arg_count - 1)
result.args = self.args + [type_arg]
return result
def __eq__(self, other: Any) -> bool:
if not isinstance(other, PhasmType):
raise NotImplementedError
return (
self.name == other.name
and self.args == other.args
and self.arg_count == other.arg_count
)
def __ne__(self, other: Any) -> bool:
if not isinstance(other, PhasmType):
raise NotImplementedError
return (
self.name != other.name
or self.args != other.args
or self.arg_count != other.arg_count
)
def __repr__(self) -> str: def __repr__(self) -> str:
return 'PhasmType' + self.name return (
'PhasmType' + self.name + ' '
+ ' '.join(map(repr, self.args)) + ' '
+ ' '.join(['?'] * self.arg_count)
).strip()
PhasmTypeInteger = PhasmType('Integer') PhasmTypeInteger = PhasmType('Integer')
PhasmTypeReal = PhasmType('Real') PhasmTypeReal = PhasmType('Real')
PhasmTypeStaticArray = PhasmType('StaticArray', 1)
class TypingNarrowProtoError(TypingError): class TypingNarrowProtoError(TypingError):
""" """
@ -331,6 +369,12 @@ class TypeVar:
def add_location(self, ref: str) -> None: def add_location(self, ref: str) -> None:
self.ctx.var_locations[self.ctx_id].add(ref) self.ctx.var_locations[self.ctx_id].add(ref)
def __eq__(self, other: Any) -> bool:
raise NotImplementedError
def __ne__(self, other: Any) -> bool:
raise NotImplementedError
def __repr__(self) -> str: def __repr__(self) -> str:
typ = self.ctx.var_types[self.ctx_id] typ = self.ctx.var_types[self.ctx_id]
@ -396,8 +440,7 @@ class Context:
if l_type is not None and r_type is not None and l_type != r_type: if l_type is not None and r_type is not None and l_type != r_type:
raise TypingNarrowError(l, r, 'Type does not match') raise TypingNarrowError(l, r, 'Type does not match')
else: self.var_types[n.ctx_id] = l_type
self.var_types[n.ctx_id] = l_type
try: try:
for const in self.var_constraints[l_ctx_id].values(): for const in self.var_constraints[l_ctx_id].values():
@ -579,23 +622,22 @@ def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar:
result.add_location(location) result.add_location(location)
return result return result
# match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
# if match: if match:
# result = ctx.new_var() result = ctx.new_var(PhasmTypeStaticArray)
#
# result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY)) result.add_constraint(TypeConstraintSubscript(members=(
# result.add_constraint(TypeConstraintSubscript(members=( # Make copies so they don't get entangled
# # Make copies so they don't get entangled # with each other.
# # with each other. from_str(ctx, match[1], match[1])
# from_str(ctx, match[1], match[1]) for _ in range(int(match[2]))
# for _ in range(int(match[2])) )))
# )))
# result.add_location(inp)
# result.add_location(inp)
# if location is not None:
# if location is not None: result.add_location(location)
# result.add_location(location)
# return result
# return result
raise NotImplementedError(from_str, inp) raise NotImplementedError(from_str, inp)

View File

@ -1,6 +1,6 @@
import pytest import pytest
from phasm.exceptions import StaticError, TypingError from phasm.exceptions import TypingError
from ..constants import ( from ..constants import (
ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_NUMERIC_TYPES, TYPE_MAP ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_NUMERIC_TYPES, TYPE_MAP

View File

@ -1,5 +1,8 @@
import pytest import pytest
from phasm.exceptions import StaticError
from phasm.parser import phasm_parse
from ..helpers import Suite from ..helpers import Suite
@pytest.mark.integration_test @pytest.mark.integration_test