Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
6 changed files with 163 additions and 189 deletions
Showing only changes of commit 0097ce782d - Show all commits

View File

@ -95,26 +95,23 @@ def expression(inp: ourlang.Expression) -> str:
# return f'({args}, )'
return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.AccessBytesIndex):
return f'{expression(inp.varref)}[{expression(inp.index)}]'
if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member.name}'
if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )):
if isinstance(inp.member, ourlang.Expression):
return f'{expression(inp.varref)}[{expression(inp.member)}]'
return f'{expression(inp.varref)}[{inp.member.idx}]'
#
# if isinstance(inp, ourlang.AccessBytesIndex):
# return f'{expression(inp.varref)}[{expression(inp.index)}]'
#
# if isinstance(inp, ourlang.AccessStructMember):
# return f'{expression(inp.varref)}.{inp.member.name}'
#
# if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )):
# if isinstance(inp.member, ourlang.Expression):
# return f'{expression(inp.varref)}[{expression(inp.member)}]'
#
# return f'{expression(inp.varref)}[{inp.member.idx}]'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
if isinstance(inp, ourlang.ModuleConstantReference):
return inp.definition.name
raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements:

View File

@ -156,8 +156,39 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
raise NotImplementedError(f'Constants with type {stp}')
if isinstance(inp, ourlang.VariableReference):
wgn.add_statement('local.get', '${}'.format(inp.variable.name))
return
if isinstance(inp.variable, ourlang.FunctionParam):
wgn.add_statement('local.get', '${}'.format(inp.variable.name))
return
if isinstance(inp.variable, ourlang.ModuleConstantDef):
# FIXME: Tuple / Static Array broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
#
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
assert inp.variable.data_block is None, 'Primitives are not memory stored'
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.variable.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.variable.constant)
return
raise NotImplementedError(expression, inp.variable)
if isinstance(inp, ourlang.BinaryOp):
expression(wgn, inp.left)
@ -301,34 +332,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression_fold(wgn, inp)
return
if isinstance(inp, ourlang.ModuleConstantReference):
# FIXME: Tuple / Static Array broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
#
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
assert inp.definition.data_block is None, 'Primitives are not memory stored'
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.definition.constant)
return
raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
@ -566,38 +569,46 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
data_list: List[bytes] = []
raise NotImplementedError('Broken after new type system')
for constant in block.data:
if isinstance(constant, ourlang.ConstantUInt8):
assert constant.type_var is not None
mtyp = typing.simplify(constant.type_var)
if mtyp == 'u8':
assert isinstance(constant.value, int)
data_list.append(module_data_u8(constant.value))
continue
if isinstance(constant, ourlang.ConstantUInt32):
if mtyp == 'u32':
assert isinstance(constant.value, int)
data_list.append(module_data_u32(constant.value))
continue
if isinstance(constant, ourlang.ConstantUInt64):
if mtyp == 'u64':
assert isinstance(constant.value, int)
data_list.append(module_data_u64(constant.value))
continue
if isinstance(constant, ourlang.ConstantInt32):
if mtyp == 'i32':
assert isinstance(constant.value, int)
data_list.append(module_data_i32(constant.value))
continue
if isinstance(constant, ourlang.ConstantInt64):
if mtyp == 'i64':
assert isinstance(constant.value, int)
data_list.append(module_data_i64(constant.value))
continue
if isinstance(constant, ourlang.ConstantFloat32):
if mtyp == 'f32':
assert isinstance(constant.value, float)
data_list.append(module_data_f32(constant.value))
continue
if isinstance(constant, ourlang.ConstantFloat64):
if mtyp == 'f64':
assert isinstance(constant.value, float)
data_list.append(module_data_f64(constant.value))
continue
raise NotImplementedError(constant)
raise NotImplementedError(constant, mtyp)
block_data = b''.join(data_list)

View File

@ -11,10 +11,7 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc',
WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
from .typing import (
TypeBytes,
TypeTuple, TypeTupleMember,
TypeStaticArray, TypeStaticArrayMember,
TypeStruct, TypeStructMember,
TypeStruct,
TypeVar,
)
@ -78,9 +75,9 @@ class VariableReference(Expression):
"""
__slots__ = ('variable', )
variable: 'FunctionParam' # also possibly local
variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local
def __init__(self, variable: 'FunctionParam') -> None:
def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None:
super().__init__()
self.variable = variable
@ -131,9 +128,10 @@ class FunctionCall(Expression):
self.function = function
self.arguments = []
class AccessBytesIndex(Expression):
class Subscript(Expression):
"""
Access a bytes index for reading
A subscript, for example to refer to a static array or tuple
by index
"""
__slots__ = ('varref', 'index', )
@ -146,53 +144,6 @@ class AccessBytesIndex(Expression):
self.varref = varref
self.index = index
class AccessStructMember(Expression):
"""
Access a struct member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: TypeStructMember
def __init__(self, varref: VariableReference, member: TypeStructMember) -> None:
super().__init__()
self.varref = varref
self.member = member
class AccessTupleMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: TypeTupleMember
def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None:
super().__init__()
self.varref = varref
self.member = member
class AccessStaticArrayMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'static_array', 'member', )
varref: Union['ModuleConstantReference', VariableReference]
static_array: TypeStaticArray
member: Union[Expression, TypeStaticArrayMember]
def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None:
super().__init__()
self.varref = varref
self.static_array = static_array
self.member = member
class Fold(Expression):
"""
A (left or right) fold
@ -223,18 +174,6 @@ class Fold(Expression):
self.base = base
self.iter = iter_
class ModuleConstantReference(Expression):
"""
An reference to a module constant expression within a statement
"""
__slots__ = ('definition', )
definition: 'ModuleConstantDef'
def __init__(self, definition: 'ModuleConstantDef') -> None:
super().__init__()
self.definition = definition
class Statement:
"""
A statement within a function

View File

@ -22,15 +22,14 @@ from .ourlang import (
Function,
Expression,
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp,
ConstantPrimitive, ConstantTuple, ConstantStaticArray,
FunctionCall,
FunctionCall, Subscript,
# StructConstructor, TupleConstructor,
UnaryOp, VariableReference,
Fold, ModuleConstantReference,
Fold,
Statement,
StatementIf, StatementPass, StatementReturn,
@ -206,6 +205,27 @@ class OurVisitor:
None,
)
if isinstance(node.value, ast.Tuple):
tuple_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node in node.value.elts
if isinstance(arg_node, ast.Constant)
]
if len(node.value.elts) != len(tuple_data):
_raise_static_error(node, 'Tuple arguments must be constants')
# Allocate the data
data_block = ModuleDataBlock(tuple_data)
module.data.blocks.append(data_block)
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
ConstantTuple(tuple_data),
data_block,
)
raise NotImplementedError('TODO: Broken after new typing system')
# if isinstance(exp_type, TypeTuple):
@ -416,7 +436,7 @@ class OurVisitor:
if node.id in module.constant_defs:
cdef = module.constant_defs[node.id]
return ModuleConstantReference(cdef)
return VariableReference(cdef)
_raise_static_error(node, f'Undefined variable {node.id}')
@ -454,13 +474,13 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.func.id in module.structs:
raise NotImplementedError('TODO: Broken after new type system')
# if node.func.id in module.structs:
# raise NotImplementedError('TODO: Broken after new type system')
# struct = module.structs[node.func.id]
# struct_constructor = StructConstructor(struct)
#
# func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
@ -533,61 +553,59 @@ class OurVisitor:
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
raise NotImplementedError('Broken after new type system')
del module
del function
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}')
param = our_locals[node.value.id]
node_typ = param.type
if not isinstance(node_typ, TypeStruct):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
member = node_typ.get_member(node.attr)
if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
return AccessStructMember(
VariableReference(param),
member,
)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
raise NotImplementedError('TODO: Broken after new type system')
# del module
# del function
#
# if not isinstance(node.value, ast.Name):
# _raise_static_error(node, 'Must reference a name')
#
# if not isinstance(node.slice, ast.Index):
# _raise_static_error(node, 'Must subscript using an index')
#
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# varref: Union[ModuleConstantReference, VariableReference]
# if node.value.id in our_locals:
# param = our_locals[node.value.id]
# node_typ = param.type
# varref = VariableReference(param)
# elif node.value.id in module.constant_defs:
# constant_def = module.constant_defs[node.value.id]
# node_typ = constant_def.type
# varref = ModuleConstantReference(constant_def)
# else:
# if not node.value.id in our_locals:
# _raise_static_error(node, f'Undefined variable {node.value.id}')
#
# slice_expr = self.visit_Module_FunctionDef_expr(
# module, function, our_locals, node.slice.value,
# )
# param = our_locals[node.value.id]
#
# node_typ = param.type
# if not isinstance(node_typ, TypeStruct):
# _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
#
# member = node_typ.get_member(node.attr)
# if member is None:
# _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
#
# return AccessStructMember(
# VariableReference(param),
# member,
# )
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
varref: VariableReference
if node.value.id in our_locals:
param = our_locals[node.value.id]
varref = VariableReference(param)
elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id]
varref = VariableReference(constant_def)
else:
_raise_static_error(node, f'Undefined variable {node.value.id}')
slice_expr = self.visit_Module_FunctionDef_expr(
module, function, our_locals, node.slice.value,
)
return Subscript(varref, slice_expr)
# if isinstance(node_typ, TypeBytes):
# if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')

View File

@ -62,7 +62,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return constant(ctx, inp)
if isinstance(inp, ourlang.VariableReference):
assert inp.variable.type_var is not None, inp
assert inp.variable.type_var is not None
return inp.variable.type_var
if isinstance(inp, ourlang.UnaryOp):
@ -112,13 +112,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return inp.function.returns_type_var
if isinstance(inp, ourlang.ModuleConstantReference):
assert inp.definition.type_var is not None
inp.type_var = inp.definition.type_var
return inp.definition.type_var
raise NotImplementedError(expression, inp)
def function(ctx: Context, inp: ourlang.Function) -> None:

View File

@ -7,21 +7,18 @@ from ..helpers import Suite
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_static_array_module_constant(type_):
def test_module_constant(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
return CONSTANT[0]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert 24 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@ -43,6 +40,25 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_function_call(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_static_array_constant_too_few_values():
code_py = """