Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
12 changed files with 297 additions and 123 deletions
Showing only changes of commit 5da45e78c2 - Show all commits

View File

@ -54,7 +54,7 @@ def expression(inp: ourlang.Expression) -> str:
# could not fit in the given float type
return str(inp.value)
if isinstance(inp, (ourlang.ConstantTuple, ourlang.ConstantStaticArray, )):
if isinstance(inp, ourlang.ConstantTuple):
return '(' + ', '.join(
expression(x)
for x in inp.value
@ -65,8 +65,8 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.UnaryOp):
if (
inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS
or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS):
inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS
or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS):
return f'{inp.operator}({expression(inp.right)})'
if inp.operator == 'cast':
@ -186,7 +186,7 @@ def module(inp: ourlang.Module) -> str:
for func in inp.functions.values():
if func.lineno < 0:
# Buildin (-2) or auto generated (-1)
# Builtin (-2) or auto generated (-1)
continue
if result:

View File

@ -247,11 +247,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
mtyp = typing.simplify(inp.type_var)
if mtyp == 'f32':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f32.{inp.operator}')
return
if mtyp == 'f64':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f64.{inp.operator}')
return
@ -608,7 +608,7 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
data_list.append(module_data_f64(constant.value))
continue
raise NotImplementedError(constant, mtyp)
raise NotImplementedError(constant, constant.type_var, mtyp)
block_data = b''.join(data_list)

View File

@ -7,8 +7,8 @@ import enum
from typing_extensions import Final
WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', )
WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', )
WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', )
from .typing import (
TypeStruct,
@ -57,18 +57,6 @@ class ConstantTuple(Constant):
super().__init__()
self.value = value
class ConstantStaticArray(Constant):
"""
A StaticArray constant value expression within a statement
"""
__slots__ = ('value', )
value: List[ConstantPrimitive]
def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
super().__init__()
self.value = value
class VariableReference(Expression):
"""
An variable reference expression within a statement

View File

@ -6,24 +6,22 @@ from typing import Any, Dict, NoReturn, Union
import ast
from .typing import (
BUILTIN_TYPES,
TypeStruct,
TypeStructMember,
TypeTuple,
TypeTupleMember,
TypeStaticArray,
TypeStaticArrayMember,
)
from .exceptions import StaticError
from .ourlang import (
WEBASSEMBLY_BUILDIN_FLOAT_OPS,
WEBASSEMBLY_BUILTIN_FLOAT_OPS,
Module, ModuleDataBlock,
Function,
Expression,
BinaryOp,
ConstantPrimitive, ConstantTuple, ConstantStaticArray,
ConstantPrimitive, ConstantTuple,
FunctionCall, Subscript,
# StructConstructor, TupleConstructor,
@ -482,7 +480,7 @@ class OurVisitor:
# struct_constructor = StructConstructor(struct)
#
# func = module.functions[struct_constructor.name]
if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS:
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
@ -686,7 +684,7 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.id in ('u8', 'u32', 'u64', 'i32', 'i64', 'f32', 'f64'): # FIXME: Source this list somewhere
if node.id in BUILTIN_TYPES:
return node.id
raise NotImplementedError('TODO: Broken after type system')
@ -697,40 +695,21 @@ class OurVisitor:
_raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript):
raise NotImplementedError('TODO: Broken after new type system')
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must be name')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice.value, ast.Constant):
_raise_static_error(node, 'Must subscript using a constant index')
if not isinstance(node.slice.value.value, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
# if not isinstance(node.value, ast.Name):
# _raise_static_error(node, 'Must be name')
# if not isinstance(node.slice, ast.Index):
# _raise_static_error(node, 'Must subscript using an index')
# if not isinstance(node.slice.value, ast.Constant):
# _raise_static_error(node, 'Must subscript using a constant index')
# if not isinstance(node.slice.value.value, int):
# _raise_static_error(node, 'Must subscript using a constant integer index')
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# if node.value.id in module.types:
# member_type = module.types[node.value.id]
# else:
# _raise_static_error(node, f'Unrecognized type {node.value.id}')
#
# type_static_array = TypeStaticArray(member_type)
#
# offset = 0
#
# for idx in range(node.slice.value.value):
# static_array_member = TypeStaticArrayMember(idx, offset)
#
# type_static_array.members.append(static_array_member)
# offset += member_type.alloc_size()
#
# key = f'{node.value.id}[{node.slice.value.value}]'
#
# if key not in module.types:
# module.types[key] = type_static_array
#
# return module.types[key]
if node.value.id not in BUILTIN_TYPES: # FIXME: Tuple of tuples?
_raise_static_error(node, f'Unrecognized type {node.value.id}')
return f'{node.value.id}[{node.slice.value.value}]'
if isinstance(node, ast.Tuple):
raise NotImplementedError('TODO: Broken after new type system')

View File

@ -3,7 +3,13 @@ Type checks and enriches the given ast
"""
from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar, from_str
from .exceptions import TypingError
from .typing import (
Context,
TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript,
TypeVar,
from_str,
)
def phasm_type(inp: ourlang.Module) -> None:
module(inp)
@ -55,6 +61,19 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
raise NotImplementedError(constant, inp, inp.value)
if isinstance(inp, ourlang.ConstantTuple):
result = ctx.new_var()
result.add_constraint(TypeConstraintSubscript(members=(
constant(ctx, x)
for x in inp.value
)))
result.add_location(str(inp.value))
inp.type_var = result
return result
raise NotImplementedError(constant, inp)
def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
@ -63,6 +82,8 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.VariableReference):
assert inp.variable.type_var is not None
inp.type_var = inp.variable.type_var
return inp.variable.type_var
if isinstance(inp, ourlang.UnaryOp):
@ -112,6 +133,32 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return inp.function.returns_type_var
if isinstance(inp, ourlang.Subscript):
if not isinstance(inp.index, ourlang.ConstantPrimitive):
raise NotImplementedError(expression, inp, inp.index)
if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value)
expression(ctx, inp.varref)
assert inp.varref.type_var is not None
try:
# TODO: I'd much rather resolve this using the narrow functions
tc_subs = ctx.var_constraints[inp.varref.type_var.ctx_id][TypeConstraintSubscript]
except KeyError:
raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None
assert isinstance(tc_subs, TypeConstraintSubscript) # type hint
try:
# TODO: I'd much rather resolve this using the narrow functions
member = tc_subs.members[inp.index.value]
except IndexError:
raise TypingError(f'Type cannot be subscripted with index {inp.index.value}: {inp.varref.type_var}') from None
inp.type_var = member
return member
raise NotImplementedError(expression, inp)
def function(ctx: Context, inp: ourlang.Function) -> None:

View File

@ -1,9 +1,10 @@
"""
The phasm type system
"""
from typing import Dict, Iterable, Optional, List, Set, Type
from typing import Callable, Dict, Iterable, Optional, List, Set, Type
import enum
import re
from .exceptions import TypingError
@ -151,7 +152,7 @@ class TypeConstraintBase:
"""
Base class for classes implementing a contraint on a type
"""
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase':
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
@ -172,7 +173,7 @@ class TypeConstraintPrimitive(TypeConstraintBase):
def __init__(self, primitive: Primitive) -> None:
self.primitive = primitive
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
if not isinstance(other, TypeConstraintPrimitive):
raise Exception('Invalid comparison')
@ -196,7 +197,7 @@ class TypeConstraintSigned(TypeConstraintBase):
def __init__(self, signed: Optional[bool]) -> None:
self.signed = signed
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
if not isinstance(other, TypeConstraintSigned):
raise Exception('Invalid comparison')
@ -239,7 +240,7 @@ class TypeConstraintBitWidth(TypeConstraintBase):
if x <= maxb
}
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
if not isinstance(other, TypeConstraintBitWidth):
raise Exception('Invalid comparison')
@ -251,7 +252,60 @@ class TypeConstraintBitWidth(TypeConstraintBase):
return TypeConstraintBitWidth(oneof=new_oneof)
def __repr__(self) -> str:
return 'BitWidth=oneof(' + ','.join(map(str, sorted(self.oneof))) + ')'
result = 'BitWidth='
items = list(sorted(self.oneof))
if not items:
return result
while items:
itm = items.pop(0)
result += str(itm)
cnt = 0
while cnt < len(items) and items[cnt] == itm + cnt + 1:
cnt += 1
if cnt == 1:
result += ',' + str(items[0])
elif cnt > 1:
result += '..' + str(items[cnt - 1])
items = items[cnt:]
if items:
result += ','
return result
class TypeConstraintSubscript(TypeConstraintBase):
"""
Contraint on allowing a type to be subscripted
"""
__slots__ = ('members', )
members: List['TypeVar']
def __init__(self, *, members: Iterable['TypeVar']) -> None:
self.members = list(members)
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript':
if not isinstance(other, TypeConstraintSubscript):
raise Exception('Invalid comparison')
if len(self.members) != len(other.members):
raise TypingNarrowProtoError('Member count does not match')
newmembers = []
for smb, omb in zip(self.members, other.members):
nmb = ctx.new_var()
ctx.unify(nmb, smb)
ctx.unify(nmb, omb)
newmembers.append(nmb)
return TypeConstraintSubscript(members=newmembers)
def __repr__(self) -> str:
return 'Subscript=(' + ','.join(map(repr, self.members)) + ')'
class TypeVar:
"""
@ -271,7 +325,7 @@ class TypeVar:
csts = self.ctx.var_constraints[self.ctx_id]
if newconst.__class__ in csts:
csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst)
csts[newconst.__class__] = csts[newconst.__class__].narrow(self.ctx, newconst)
else:
csts[newconst.__class__] = newconst
@ -413,6 +467,93 @@ def simplify(inp: TypeVar) -> Optional[str]:
return None
def make_u8(ctx: Context, location: str) -> TypeVar:
"""
Makes a u8 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
def make_u32(ctx: Context, location: str) -> TypeVar:
"""
Makes a u32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
def make_u64(ctx: Context, location: str) -> TypeVar:
"""
Makes a u64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
def make_i32(ctx: Context, location: str) -> TypeVar:
"""
Makes a i32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
def make_i64(ctx: Context, location: str) -> TypeVar:
"""
Makes a i64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
def make_f32(ctx: Context, location: str) -> TypeVar:
"""
Makes a f32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
def make_f64(ctx: Context, location: str) -> TypeVar:
"""
Makes a f64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = {
'u8': make_u8,
'u32': make_u32,
'u64': make_u64,
'i32': make_i32,
'i64': make_i64,
'f32': make_f32,
'f64': make_f64,
}
TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]')
def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
"""
Creates a new TypeVar from the string
@ -425,53 +566,21 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
This could be conidered part of parsing. Though that would give trouble
with the context creation.
"""
if inp in BUILTIN_TYPES:
return BUILTIN_TYPES[inp](ctx, location)
match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
if match:
result = ctx.new_var()
if inp == 'u8':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_constraint(TypeConstraintSubscript(members=(
# Make copies so they don't get entangled
# with each other.
from_str(ctx, match[1], match[1])
for _ in range(int(match[2]))
)))
result.add_location(location)
return result
if inp == 'u32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'u64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'i32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if inp == 'i64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if inp == 'f32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
if inp == 'f64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location)
return result
raise NotImplementedError(from_str, inp)

View File

@ -7,4 +7,4 @@ max-line-length=180
good-names=g
[tests]
disable=C0116,
disable=C0116,R0201

View File

@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO
import ctypes
import io
import warnings
import pywasm.binary
import wasm3
@ -42,10 +41,7 @@ class RunnerBase:
Parses the Phasm code into an AST
"""
self.phasm_ast = phasm_parse(self.phasm_code)
try:
phasm_type(self.phasm_ast)
except NotImplementedError as exc:
warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}')
def compile_ast(self) -> None:
"""

View File

View File

@ -0,0 +1,20 @@
import pytest
from phasm import typing as sut
class TestTypeConstraintBitWidth:
@pytest.mark.parametrize('oneof,exp', [
(set(), '', ),
({1}, '1', ),
({1,2}, '1,2', ),
({1,2,3}, '1..3', ),
({1,2,3,4}, '1..4', ),
({1,3}, '1,3', ),
({1,4}, '1,4', ),
({1,2,3,4,6,7,8,9}, '1..4,6..9', ),
])
def test_repr(self, oneof, exp):
mut_self = sut.TypeConstraintBitWidth(oneof=oneof)
assert ('BitWidth=' + exp) == repr(mut_self)

View File

@ -219,7 +219,7 @@ def testEntry() -> {type_}:
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['f32', 'f64'])
def test_buildins_sqrt(type_):
def test_builtins_sqrt(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:

View File

@ -1,12 +1,12 @@
import pytest
from phasm.exceptions import StaticError
from phasm.exceptions import StaticError, TypingError
from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
from ..constants import ALL_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
from ..helpers import Suite
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
@pytest.mark.parametrize('type_', ALL_INT_TYPES)
def test_module_constant(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@ -59,6 +59,41 @@ def helper(array: {type_}[3]) -> {type_}:
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_module_constant_type_mismatch_bitwidth():
code_py = """
CONSTANT: u8[3] = (24, 57, 280, )
"""
with pytest.raises(TypingError, match='u8.*280'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_module_constant_type_mismatch_not_subscriptable():
code_py = """
CONSTANT: u8 = 24
@exported
def testEntry() -> u8:
return CONSTANT[0]
"""
with pytest.raises(TypingError, match='Type cannot be subscripted:'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_module_constant_type_mismatch_index_out_of_range():
code_py = """
CONSTANT: u8[3] = (24, 57, 80, )
@exported
def testEntry() -> u8:
return CONSTANT[3]
"""
with pytest.raises(TypingError, match='Type cannot be subscripted with index 3:'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_static_array_constant_too_few_values():
code_py = """