From 564f00a419c05a5b05e42375dca6067652877adb Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 20:13:16 +0200 Subject: [PATCH] Work on ripping out old type system --- phasm/codestyle.py | 69 +++++++-------------------- phasm/compiler.py | 108 ++++++++++++++++++++++++------------------ phasm/exceptions.py | 4 +- phasm/ourlang.py | 52 ++++++++++---------- phasm/parser.py | 69 ++++++--------------------- phasm/stdlib/alloc.py | 2 +- phasm/typer.py | 16 +++++++ phasm/typing.py | 42 +++++++++++++++- phasm/wasmeasy.py | 2 +- pylintrc | 2 +- 10 files changed, 184 insertions(+), 182 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index bbb2eff..334b3eb 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -3,7 +3,7 @@ This module generates source code based on the parsed AST It's intented to be a "any color, as long as it's black" kind of renderer """ -from typing import Generator +from typing import Generator, Optional from . import ourlang from . import typing @@ -16,55 +16,17 @@ def phasm_render(inp: ourlang.Module) -> str: Statements = Generator[str, None, None] -def type_(inp: typing.TypeBase) -> str: +def type_var(inp: Optional[typing.TypeVar]) -> str: """ - Render: Type (name) + Render: type's name """ - if isinstance(inp, typing.TypeNone): - return 'None' + assert inp is not None, typing.ASSERTION_ERROR - if isinstance(inp, typing.TypeBool): - return 'bool' + mtyp = typing.simplify(inp) + if mtyp is None: + raise NotImplementedError(f'Rendering type {inp}') - if isinstance(inp, typing.TypeUInt8): - return 'u8' - - if isinstance(inp, typing.TypeUInt32): - return 'u32' - - if isinstance(inp, typing.TypeUInt64): - return 'u64' - - if isinstance(inp, typing.TypeInt32): - return 'i32' - - if isinstance(inp, typing.TypeInt64): - return 'i64' - - if isinstance(inp, typing.TypeFloat32): - return 'f32' - - if isinstance(inp, typing.TypeFloat64): - return 'f64' - - if isinstance(inp, typing.TypeBytes): - return 'bytes' - - if isinstance(inp, typing.TypeTuple): - mems = ', '.join( - type_(x.type) - for x in inp.members - ) - - return f'({mems}, )' - - if isinstance(inp, typing.TypeStaticArray): - return f'{type_(inp.member_type)}[{len(inp.members)}]' - - if isinstance(inp, typing.TypeStruct): - return inp.name - - raise NotImplementedError(type_, inp) + return mtyp def struct_definition(inp: typing.TypeStruct) -> str: """ @@ -72,7 +34,8 @@ def struct_definition(inp: typing.TypeStruct) -> str: """ result = f'class {inp.name}:\n' for mem in inp.members: - result += f' {mem.name}: {type_(mem.type)}\n' + raise NotImplementedError('Structs broken after new type system') + # result += f' {mem.name}: {type_(mem.type)}\n' return result @@ -80,7 +43,7 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str: """ Render: Module Constant's definition """ - return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' + return f'{inp.name}: {type_var(inp.type_var)} = {expression(inp.constant)}\n' def expression(inp: ourlang.Expression) -> str: """ @@ -107,7 +70,11 @@ def expression(inp: ourlang.Expression) -> str: return f'{inp.operator}({expression(inp.right)})' if inp.operator == 'cast': - return f'{type_(inp.type)}({expression(inp.right)})' + mtyp = type_var(inp.type_var) + if mtyp is None: + raise NotImplementedError(f'Casting to type {inp.type_var}') + + return f'{mtyp}({expression(inp.right)})' return f'{inp.operator}{expression(inp.right)}' @@ -187,11 +154,11 @@ def function(inp: ourlang.Function) -> str: result += '@imported\n' args = ', '.join( - f'{p.name}: {type_(p.type)}' + f'{p.name}: {type_var(p.type_var)}' for p in inp.posonlyargs ) - result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' + result += f'def {inp.name}({args}) -> {type_var(inp.returns_type_var)}:\n' if inp.imported: result += ' pass\n' diff --git a/phasm/compiler.py b/phasm/compiler.py index af6ae58..edd9066 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -1,6 +1,8 @@ """ This module contains the code to convert parsed Ourlang into WebAssembly code """ +from typing import List + import struct from . import codestyle @@ -132,7 +134,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: Compile: Any expression """ if isinstance(inp, ourlang.ConstantPrimitive): - assert inp.type_var is not None + assert inp.type_var is not None, typing.ASSERTION_ERROR stp = typing.simplify(inp.type_var) if stp is None: @@ -174,73 +176,80 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression(wgn, inp.left) expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeUInt8): + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + + if mtyp == 'u8': if operator := U8_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt32): + if mtyp == 'u32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := U32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt64): + if mtyp == 'u64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := U64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeInt32): + if mtyp == 'i32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := I32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeInt64): + if mtyp == 'i64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeFloat32): + if mtyp == 'f32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return - if isinstance(inp.type, typing.TypeFloat64): + if mtyp == 'f64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f64.{operator}') return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type_var, inp.operator) if isinstance(inp, ourlang.UnaryOp): expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeFloat32): + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + + if mtyp == 'f32': if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return - if isinstance(inp.type, typing.TypeFloat64): + if mtyp == 'f64': if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return - if isinstance(inp.type, typing.TypeInt32): - if inp.operator == 'len': - if isinstance(inp.right.type, typing.TypeBytes): - wgn.i32.load() - return + # TODO: Broken after new type system + # if isinstance(inp.type, typing.TypeInt32): + # if inp.operator == 'len': + # if isinstance(inp.right.type, typing.TypeBytes): + # wgn.i32.load() + # return - if inp.operator == 'cast': - if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): - # Nothing to do, you can use an u8 value as a u32 no problem - return + # if inp.operator == 'cast': + # if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): + # # Nothing to do, you can use an u8 value as a u32 no problem + # return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type_var, inp.operator) if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: @@ -249,14 +258,15 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.add_statement('call', '${}'.format(inp.function.name)) return - if isinstance(inp, ourlang.AccessBytesIndex): - if not isinstance(inp.type, typing.TypeUInt8): - raise NotImplementedError(inp, inp.type) - - expression(wgn, inp.varref) - expression(wgn, inp.index) - wgn.call(stdlib_types.__subscript_bytes__) - return + # TODO: Broken after new type system + # if isinstance(inp, ourlang.AccessBytesIndex): + # if not isinstance(inp.type, typing.TypeUInt8): + # raise NotImplementedError(inp, inp.type) + # + # expression(wgn, inp.varref) + # expression(wgn, inp.index) + # wgn.call(stdlib_types.__subscript_bytes__) + # return if isinstance(inp, ourlang.AccessStructMember): mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) @@ -305,27 +315,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: return if isinstance(inp, ourlang.ModuleConstantReference): - if isinstance(inp.type, typing.TypeTuple): - assert isinstance(inp.definition.constant, ourlang.ConstantTuple) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return - - if isinstance(inp.type, typing.TypeStaticArray): - assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return + # FIXME: Tuple / Static Array broken after new type system + # if isinstance(inp.type, typing.TypeTuple): + # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + # + # if isinstance(inp.type, typing.TypeStaticArray): + # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return assert inp.definition.data_block is None, 'Primitives are not memory stored' - mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__) + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.type) + raise NotImplementedError(expression, inp, inp.type_var) expression(wgn, inp.definition.constant) return @@ -336,13 +348,15 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ - assert inp.base.type_var is not None + assert inp.base.type_var is not None, typing.ASSERTION_ERROR mtyp = typing.simplify(inp.base.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples raise NotImplementedError(expression, inp, inp.base) + raise NotImplementedError('TODO: Broken after new type system') + if inp.iter.type.__class__.__name__ != 'TypeBytes': raise NotImplementedError(expression, inp, inp.iter.type) @@ -563,7 +577,9 @@ def module_data(inp: ourlang.ModuleData) -> bytes: for block in inp.blocks: block.address = unalloc_ptr + 4 # 4 bytes for allocator header - data_list = [] + data_list: List[bytes] = [] + + raise NotImplementedError('Broken after new type system') for constant in block.data: if isinstance(constant, ourlang.ConstantUInt8): diff --git a/phasm/exceptions.py b/phasm/exceptions.py index abbcdd2..77c75e7 100644 --- a/phasm/exceptions.py +++ b/phasm/exceptions.py @@ -8,4 +8,6 @@ class StaticError(Exception): """ class TypingError(Exception): - pass + """ + An error found during the typing phase + """ diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 30a2527..ebc13d7 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,7 +1,7 @@ """ Contains the syntax tree for ourlang """ -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Optional, Union import enum @@ -13,7 +13,6 @@ WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) from .typing import ( TypeBase, TypeNone, - TypeBool, TypeUInt8, TypeUInt32, TypeUInt64, TypeInt32, TypeInt64, TypeFloat32, TypeFloat64, @@ -29,13 +28,11 @@ class Expression: """ An expression within a statement """ - __slots__ = ('type', 'type_var', ) + __slots__ = ('type_var', ) - type: TypeBase type_var: Optional[TypeVar] - def __init__(self, type_: TypeBase) -> None: - self.type = type_ + def __init__(self) -> None: self.type_var = None class Constant(Expression): @@ -53,6 +50,7 @@ class ConstantPrimitive(Constant): value: Union[int, float] def __init__(self, value: Union[int, float]) -> None: + super().__init__() self.value = value class ConstantTuple(Constant): @@ -63,8 +61,8 @@ class ConstantTuple(Constant): value: List[ConstantPrimitive] - def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? - super().__init__(type_) + def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? + super().__init__() self.value = value class ConstantStaticArray(Constant): @@ -75,8 +73,8 @@ class ConstantStaticArray(Constant): value: List[ConstantPrimitive] - def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? - super().__init__(type_) + def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? + super().__init__() self.value = value class VariableReference(Expression): @@ -87,8 +85,8 @@ class VariableReference(Expression): variable: 'FunctionParam' # also possibly local - def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None: - super().__init__(type_) + def __init__(self, variable: 'FunctionParam') -> None: + super().__init__() self.variable = variable class UnaryOp(Expression): @@ -100,8 +98,8 @@ class UnaryOp(Expression): operator: str right: Expression - def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, right: Expression) -> None: + super().__init__() self.operator = operator self.right = right @@ -116,8 +114,8 @@ class BinaryOp(Expression): left: Expression right: Expression - def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, left: Expression, right: Expression) -> None: + super().__init__() self.operator = operator self.left = left @@ -133,7 +131,7 @@ class FunctionCall(Expression): arguments: List[Expression] def __init__(self, function: 'Function') -> None: - super().__init__(function.returns) + super().__init__() self.function = function self.arguments = [] @@ -147,8 +145,8 @@ class AccessBytesIndex(Expression): varref: VariableReference index: Expression - def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None: - super().__init__(type_) + def __init__(self, varref: VariableReference, index: Expression) -> None: + super().__init__() self.varref = varref self.index = index @@ -163,7 +161,7 @@ class AccessStructMember(Expression): member: TypeStructMember def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: - super().__init__(member.type) + super().__init__() self.varref = varref self.member = member @@ -178,7 +176,7 @@ class AccessTupleMember(Expression): member: TypeTupleMember def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: - super().__init__(member.type) + super().__init__() self.varref = varref self.member = member @@ -194,7 +192,7 @@ class AccessStaticArrayMember(Expression): member: Union[Expression, TypeStaticArrayMember] def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: - super().__init__(static_array.member_type) + super().__init__() self.varref = varref self.static_array = static_array @@ -218,13 +216,12 @@ class Fold(Expression): def __init__( self, - type_: TypeBase, dir_: Direction, func: 'Function', base: Expression, iter_: Expression, ) -> None: - super().__init__(type_) + super().__init__() self.dir = dir_ self.func = func @@ -239,8 +236,8 @@ class ModuleConstantReference(Expression): definition: 'ModuleConstantDef' - def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None: - super().__init__(type_) + def __init__(self, definition: 'ModuleConstantDef') -> None: + super().__init__() self.definition = definition class Statement: @@ -280,6 +277,9 @@ class StatementIf(Statement): self.else_statements = [] class FunctionParam: + """ + A parameter for a Function + """ __slots__ = ('name', 'type', 'type_var', ) name: str diff --git a/phasm/parser.py b/phasm/parser.py index 8b1b3a2..d8e94af 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -7,13 +7,6 @@ import ast from .typing import ( TypeBase, - TypeUInt8, - TypeUInt32, - TypeUInt64, - TypeInt32, - TypeInt64, - TypeFloat32, - TypeFloat64, TypeBytes, TypeStruct, TypeStructMember, @@ -34,7 +27,6 @@ from .ourlang import ( Expression, AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, - Constant, ConstantPrimitive, ConstantTuple, ConstantStaticArray, FunctionCall, @@ -242,7 +234,7 @@ class OurVisitor: node.target.id, node.lineno, exp_type, - ConstantTuple(exp_type, tuple_data), + ConstantTuple(tuple_data), data_block, ) @@ -270,7 +262,7 @@ class OurVisitor: node.target.id, node.lineno, exp_type, - ConstantStaticArray(exp_type, static_array_data), + ConstantStaticArray(static_array_data), data_block, ) @@ -359,7 +351,6 @@ class OurVisitor: # e.g. you can do `"hello" * 3` with the code below (yet) return BinaryOp( - exp_type, operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), @@ -374,7 +365,6 @@ class OurVisitor: raise NotImplementedError(f'Operator {node.op}') return UnaryOp( - exp_type, operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), ) @@ -396,7 +386,6 @@ class OurVisitor: # e.g. you can do `"hello" * 3` with the code below (yet) return BinaryOp( - exp_type, operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), @@ -412,12 +401,12 @@ class OurVisitor: if isinstance(node, ast.Attribute): return self.visit_Module_FunctionDef_Attribute( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Subscript): return self.visit_Module_FunctionDef_Subscript( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Name): @@ -426,21 +415,17 @@ class OurVisitor: if node.id in our_locals: param = our_locals[node.id] - if exp_type != param.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}') - - return VariableReference(param.type, param) + return VariableReference(param) if node.id in module.constant_defs: cdef = module.constant_defs[node.id] - if exp_type != cdef.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}') - - return ModuleConstantReference(exp_type, cdef) + return ModuleConstantReference(cdef) _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): + raise NotImplementedError('TODO: Broken after new type system') + if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') @@ -478,40 +463,28 @@ class OurVisitor: func = module.functions[struct_constructor.name] elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: - if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): - _raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'sqrt', self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), ) elif node.func.id == 'u32': - if not isinstance(exp_type, TypeUInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') # FIXME: This is a stub, proper casting is todo return UnaryOp( - exp_type, 'cast', self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), ) elif node.func.id == 'len': - if not isinstance(exp_type, TypeInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'len', self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), ) @@ -536,6 +509,8 @@ class OurVisitor: if 2 != len(func.posonlyargs): _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') + raise NotImplementedError('TODO: Broken after new type system') + if exp_type.__class__ != func.returns.__class__: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') @@ -546,7 +521,6 @@ class OurVisitor: _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') return Fold( - exp_type, Fold.Direction.LEFT, func, self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), @@ -571,7 +545,7 @@ class OurVisitor: ) return result - def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: + def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: del module del function @@ -594,15 +568,12 @@ class OurVisitor: if member is None: _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - if exp_type != member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') - return AccessStructMember( - VariableReference(node_typ, param), + VariableReference(param), member, ) - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must reference a name') @@ -616,11 +587,11 @@ class OurVisitor: if node.value.id in our_locals: param = our_locals[node.value.id] node_typ = param.type - varref = VariableReference(param.type, param) + varref = VariableReference(param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] node_typ = constant_def.type - varref = ModuleConstantReference(node_typ, constant_def) + varref = ModuleConstantReference(constant_def) else: _raise_static_error(node, f'Undefined variable {node.value.id}') @@ -629,15 +600,10 @@ class OurVisitor: ) if isinstance(node_typ, TypeBytes): - t_u8 = module.types['u8'] - if exp_type != t_u8: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}') - if isinstance(varref, ModuleConstantReference): raise NotImplementedError(f'{node} from module constant') return AccessBytesIndex( - t_u8, varref, slice_expr, ) @@ -655,8 +621,6 @@ class OurVisitor: _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') tuple_member = node_typ.members[idx] - if exp_type != tuple_member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}') if isinstance(varref, ModuleConstantReference): raise NotImplementedError(f'{node} from module constant') @@ -667,9 +631,6 @@ class OurVisitor: ) if isinstance(node_typ, TypeStaticArray): - if exp_type != node_typ.member_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') - if not isinstance(slice_expr, ConstantPrimitive): return AccessStaticArrayMember( varref, diff --git a/phasm/stdlib/alloc.py b/phasm/stdlib/alloc.py index 2761bfb..8c5742d 100644 --- a/phasm/stdlib/alloc.py +++ b/phasm/stdlib/alloc.py @@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32: g.i32.const(0) g.return_() - del alloc_size # TODO + del alloc_size # TODO: Actual implement using a previously freed block g.unreachable() return i32('return') # To satisfy mypy diff --git a/phasm/typer.py b/phasm/typer.py index de55ef3..1230d01 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -66,19 +66,27 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': return inp.variable.type_var if isinstance(inp, ourlang.UnaryOp): + # TODO: Simplified version if inp.operator not in ('sqrt', ): raise NotImplementedError(expression, inp, inp.operator) right = expression(ctx, inp.right) + + inp.type_var = right + return right if isinstance(inp, ourlang.BinaryOp): + # TODO: Simplified version if inp.operator not in ('+', '-', '*', '|', '&', '^'): raise NotImplementedError(expression, inp, inp.operator) left = expression(ctx, inp.left) right = expression(ctx, inp.right) ctx.unify(left, right) + + inp.type_var = left + return left if isinstance(inp, ourlang.FunctionCall): @@ -94,6 +102,9 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.ModuleConstantReference): assert inp.definition.type_var is not None + + inp.type_var = inp.definition.type_var + return inp.definition.type_var raise NotImplementedError(expression, inp) @@ -133,30 +144,35 @@ def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> Type result = ctx.new_var() if isinstance(inp, typing.TypeUInt8): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintSigned(False)) result.add_location(location) return result if isinstance(inp, typing.TypeUInt32): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) result.add_location(location) return result if isinstance(inp, typing.TypeUInt64): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(False)) result.add_location(location) return result if isinstance(inp, typing.TypeInt32): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(True)) result.add_location(location) return result if isinstance(inp, typing.TypeInt64): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(True)) result.add_location(location) diff --git a/phasm/typing.py b/phasm/typing.py index 0a8982b..94bb602 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -207,23 +207,45 @@ class TypeStruct(TypeBase): ## NEW STUFF BELOW +# This error can also mean that the type somewhere forgot to write a type +# back to the AST. If so, we need to fix the typer. +ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' + + class TypingNarrowProtoError(TypingError): - pass + """ + A proto error when trying to narrow two types + + This gets turned into a TypingNarrowError by the unify method + """ + # FIXME: Use consistent naming for unify / narrow / entangle class TypingNarrowError(TypingError): + """ + An error when trying to unify two Type Variables + """ def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None: super().__init__( f'Cannot narrow types {l} and {r}: {msg}' ) class TypeConstraintBase: + """ + Base class for classes implementing a contraint on a type + """ def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) class TypeConstraintPrimitive(TypeConstraintBase): + """ + This contraint on a type defines its primitive shape + """ __slots__ = ('primitive', ) class Primitive(enum.Enum): + """ + The primitive ID + """ INT = 0 FLOAT = 1 @@ -245,6 +267,10 @@ class TypeConstraintPrimitive(TypeConstraintBase): return f'Primitive={self.primitive.name}' class TypeConstraintSigned(TypeConstraintBase): + """ + Contraint on whether a signed value can be used or not, or whether + a value can be used in a signed expression + """ __slots__ = ('signed', ) signed: Optional[bool] @@ -270,6 +296,9 @@ class TypeConstraintSigned(TypeConstraintBase): return f'Signed={self.signed}' class TypeConstraintBitWidth(TypeConstraintBase): + """ + Contraint on how many bits an expression has or can possibly have + """ __slots__ = ('minb', 'maxb', ) minb: int @@ -301,6 +330,10 @@ class TypeConstraintBitWidth(TypeConstraintBase): return f'BitWidth={self.minb}..{self.maxb}' class TypeVar: + """ + A type variable + """ + # FIXME: Explain the type system __slots__ = ('ctx', 'ctx_id', ) ctx: 'Context' @@ -331,6 +364,9 @@ class TypeVar: ) class Context: + """ + The context for a collection of type variables + """ def __init__(self) -> None: # Variables are unified (or entangled, if you will) # that means that each TypeVar within a context has an ID, @@ -399,6 +435,10 @@ class Context: del self.var_locations[r_ctx_id] def simplify(inp: TypeVar) -> Optional[str]: + """ + Simplifies a TypeVar into a string that wasm can work with + and users can recognize + """ tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) diff --git a/phasm/wasmeasy.py b/phasm/wasmeasy.py index d0cf358..01c5d0c 100644 --- a/phasm/wasmeasy.py +++ b/phasm/wasmeasy.py @@ -1,7 +1,7 @@ """ Helper functions to quickly generate WASM code """ -from typing import Any, Dict, List, Optional, Type +from typing import List, Optional import functools diff --git a/pylintrc b/pylintrc index 3759e6b..82948bb 100644 --- a/pylintrc +++ b/pylintrc @@ -1,5 +1,5 @@ [MASTER] -disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 +disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223 max-line-length=180