Work on ripping out old type system

This commit is contained in:
Johan B.W. de Vries 2022-09-17 20:13:16 +02:00
parent 58f74d3e1d
commit 564f00a419
10 changed files with 184 additions and 182 deletions

View File

@ -3,7 +3,7 @@ This module generates source code based on the parsed AST
It's intented to be a "any color, as long as it's black" kind of renderer It's intented to be a "any color, as long as it's black" kind of renderer
""" """
from typing import Generator from typing import Generator, Optional
from . import ourlang from . import ourlang
from . import typing from . import typing
@ -16,55 +16,17 @@ def phasm_render(inp: ourlang.Module) -> str:
Statements = Generator[str, None, None] Statements = Generator[str, None, None]
def type_(inp: typing.TypeBase) -> str: def type_var(inp: Optional[typing.TypeVar]) -> str:
""" """
Render: Type (name) Render: type's name
""" """
if isinstance(inp, typing.TypeNone): assert inp is not None, typing.ASSERTION_ERROR
return 'None'
if isinstance(inp, typing.TypeBool): mtyp = typing.simplify(inp)
return 'bool' if mtyp is None:
raise NotImplementedError(f'Rendering type {inp}')
if isinstance(inp, typing.TypeUInt8): return mtyp
return 'u8'
if isinstance(inp, typing.TypeUInt32):
return 'u32'
if isinstance(inp, typing.TypeUInt64):
return 'u64'
if isinstance(inp, typing.TypeInt32):
return 'i32'
if isinstance(inp, typing.TypeInt64):
return 'i64'
if isinstance(inp, typing.TypeFloat32):
return 'f32'
if isinstance(inp, typing.TypeFloat64):
return 'f64'
if isinstance(inp, typing.TypeBytes):
return 'bytes'
if isinstance(inp, typing.TypeTuple):
mems = ', '.join(
type_(x.type)
for x in inp.members
)
return f'({mems}, )'
if isinstance(inp, typing.TypeStaticArray):
return f'{type_(inp.member_type)}[{len(inp.members)}]'
if isinstance(inp, typing.TypeStruct):
return inp.name
raise NotImplementedError(type_, inp)
def struct_definition(inp: typing.TypeStruct) -> str: def struct_definition(inp: typing.TypeStruct) -> str:
""" """
@ -72,7 +34,8 @@ def struct_definition(inp: typing.TypeStruct) -> str:
""" """
result = f'class {inp.name}:\n' result = f'class {inp.name}:\n'
for mem in inp.members: for mem in inp.members:
result += f' {mem.name}: {type_(mem.type)}\n' raise NotImplementedError('Structs broken after new type system')
# result += f' {mem.name}: {type_(mem.type)}\n'
return result return result
@ -80,7 +43,7 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str:
""" """
Render: Module Constant's definition Render: Module Constant's definition
""" """
return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' return f'{inp.name}: {type_var(inp.type_var)} = {expression(inp.constant)}\n'
def expression(inp: ourlang.Expression) -> str: def expression(inp: ourlang.Expression) -> str:
""" """
@ -107,7 +70,11 @@ def expression(inp: ourlang.Expression) -> str:
return f'{inp.operator}({expression(inp.right)})' return f'{inp.operator}({expression(inp.right)})'
if inp.operator == 'cast': if inp.operator == 'cast':
return f'{type_(inp.type)}({expression(inp.right)})' mtyp = type_var(inp.type_var)
if mtyp is None:
raise NotImplementedError(f'Casting to type {inp.type_var}')
return f'{mtyp}({expression(inp.right)})'
return f'{inp.operator}{expression(inp.right)}' return f'{inp.operator}{expression(inp.right)}'
@ -187,11 +154,11 @@ def function(inp: ourlang.Function) -> str:
result += '@imported\n' result += '@imported\n'
args = ', '.join( args = ', '.join(
f'{p.name}: {type_(p.type)}' f'{p.name}: {type_var(p.type_var)}'
for p in inp.posonlyargs for p in inp.posonlyargs
) )
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' result += f'def {inp.name}({args}) -> {type_var(inp.returns_type_var)}:\n'
if inp.imported: if inp.imported:
result += ' pass\n' result += ' pass\n'

View File

@ -1,6 +1,8 @@
""" """
This module contains the code to convert parsed Ourlang into WebAssembly code This module contains the code to convert parsed Ourlang into WebAssembly code
""" """
from typing import List
import struct import struct
from . import codestyle from . import codestyle
@ -132,7 +134,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
Compile: Any expression Compile: Any expression
""" """
if isinstance(inp, ourlang.ConstantPrimitive): if isinstance(inp, ourlang.ConstantPrimitive):
assert inp.type_var is not None assert inp.type_var is not None, typing.ASSERTION_ERROR
stp = typing.simplify(inp.type_var) stp = typing.simplify(inp.type_var)
if stp is None: if stp is None:
@ -174,73 +176,80 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression(wgn, inp.left) expression(wgn, inp.left)
expression(wgn, inp.right) expression(wgn, inp.right)
if isinstance(inp.type, typing.TypeUInt8): assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp == 'u8':
if operator := U8_OPERATOR_MAP.get(inp.operator, None): if operator := U8_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if isinstance(inp.type, typing.TypeUInt32): if mtyp == 'u32':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if operator := U32_OPERATOR_MAP.get(inp.operator, None): if operator := U32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if isinstance(inp.type, typing.TypeUInt64): if mtyp == 'u64':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if operator := U64_OPERATOR_MAP.get(inp.operator, None): if operator := U64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if isinstance(inp.type, typing.TypeInt32): if mtyp == 'i32':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if operator := I32_OPERATOR_MAP.get(inp.operator, None): if operator := I32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if isinstance(inp.type, typing.TypeInt64): if mtyp == 'i64':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if operator := I64_OPERATOR_MAP.get(inp.operator, None): if operator := I64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if isinstance(inp.type, typing.TypeFloat32): if mtyp == 'f32':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}') wgn.add_statement(f'f32.{operator}')
return return
if isinstance(inp.type, typing.TypeFloat64): if mtyp == 'f64':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f64.{operator}') wgn.add_statement(f'f64.{operator}')
return return
raise NotImplementedError(expression, inp.type, inp.operator) raise NotImplementedError(expression, inp.type_var, inp.operator)
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
expression(wgn, inp.right) expression(wgn, inp.right)
if isinstance(inp.type, typing.TypeFloat32): assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp == 'f32':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS:
wgn.add_statement(f'f32.{inp.operator}') wgn.add_statement(f'f32.{inp.operator}')
return return
if isinstance(inp.type, typing.TypeFloat64): if mtyp == 'f64':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS:
wgn.add_statement(f'f64.{inp.operator}') wgn.add_statement(f'f64.{inp.operator}')
return return
if isinstance(inp.type, typing.TypeInt32): # TODO: Broken after new type system
if inp.operator == 'len': # if isinstance(inp.type, typing.TypeInt32):
if isinstance(inp.right.type, typing.TypeBytes): # if inp.operator == 'len':
wgn.i32.load() # if isinstance(inp.right.type, typing.TypeBytes):
return # wgn.i32.load()
# return
if inp.operator == 'cast': # if inp.operator == 'cast':
if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): # if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8):
# Nothing to do, you can use an u8 value as a u32 no problem # # Nothing to do, you can use an u8 value as a u32 no problem
return # return
raise NotImplementedError(expression, inp.type, inp.operator) raise NotImplementedError(expression, inp.type_var, inp.operator)
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
for arg in inp.arguments: for arg in inp.arguments:
@ -249,14 +258,15 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return
if isinstance(inp, ourlang.AccessBytesIndex): # TODO: Broken after new type system
if not isinstance(inp.type, typing.TypeUInt8): # if isinstance(inp, ourlang.AccessBytesIndex):
raise NotImplementedError(inp, inp.type) # if not isinstance(inp.type, typing.TypeUInt8):
# raise NotImplementedError(inp, inp.type)
expression(wgn, inp.varref) #
expression(wgn, inp.index) # expression(wgn, inp.varref)
wgn.call(stdlib_types.__subscript_bytes__) # expression(wgn, inp.index)
return # wgn.call(stdlib_types.__subscript_bytes__)
# return
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
@ -305,27 +315,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
return return
if isinstance(inp, ourlang.ModuleConstantReference): if isinstance(inp, ourlang.ModuleConstantReference):
if isinstance(inp.type, typing.TypeTuple): # FIXME: Tuple / Static Array broken after new type system
assert isinstance(inp.definition.constant, ourlang.ConstantTuple) # if isinstance(inp.type, typing.TypeTuple):
assert inp.definition.data_block is not None, 'Combined values are memory stored' # assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
assert inp.definition.data_block.address is not None, 'Value not allocated' # assert inp.definition.data_block is not None, 'Combined values are memory stored'
wgn.i32.const(inp.definition.data_block.address) # assert inp.definition.data_block.address is not None, 'Value not allocated'
return # wgn.i32.const(inp.definition.data_block.address)
# return
if isinstance(inp.type, typing.TypeStaticArray): #
assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) # if isinstance(inp.type, typing.TypeStaticArray):
assert inp.definition.data_block is not None, 'Combined values are memory stored' # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
assert inp.definition.data_block.address is not None, 'Value not allocated' # assert inp.definition.data_block is not None, 'Combined values are memory stored'
wgn.i32.const(inp.definition.data_block.address) # assert inp.definition.data_block.address is not None, 'Value not allocated'
return # wgn.i32.const(inp.definition.data_block.address)
# return
assert inp.definition.data_block is None, 'Primitives are not memory stored' assert inp.definition.data_block is None, 'Primitives are not memory stored'
mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__) assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None: if mtyp is None:
# In the future might extend this by having structs or tuples # In the future might extend this by having structs or tuples
# as members of struct or tuples # as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type) raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.definition.constant) expression(wgn, inp.definition.constant)
return return
@ -336,13 +348,15 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
""" """
Compile: Fold expression Compile: Fold expression
""" """
assert inp.base.type_var is not None assert inp.base.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.base.type_var) mtyp = typing.simplify(inp.base.type_var)
if mtyp is None: if mtyp is None:
# In the future might extend this by having structs or tuples # In the future might extend this by having structs or tuples
# as members of struct or tuples # as members of struct or tuples
raise NotImplementedError(expression, inp, inp.base) raise NotImplementedError(expression, inp, inp.base)
raise NotImplementedError('TODO: Broken after new type system')
if inp.iter.type.__class__.__name__ != 'TypeBytes': if inp.iter.type.__class__.__name__ != 'TypeBytes':
raise NotImplementedError(expression, inp, inp.iter.type) raise NotImplementedError(expression, inp, inp.iter.type)
@ -563,7 +577,9 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
for block in inp.blocks: for block in inp.blocks:
block.address = unalloc_ptr + 4 # 4 bytes for allocator header block.address = unalloc_ptr + 4 # 4 bytes for allocator header
data_list = [] data_list: List[bytes] = []
raise NotImplementedError('Broken after new type system')
for constant in block.data: for constant in block.data:
if isinstance(constant, ourlang.ConstantUInt8): if isinstance(constant, ourlang.ConstantUInt8):

View File

@ -8,4 +8,6 @@ class StaticError(Exception):
""" """
class TypingError(Exception): class TypingError(Exception):
pass """
An error found during the typing phase
"""

View File

@ -1,7 +1,7 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
from typing import Dict, List, Tuple, Optional, Union from typing import Dict, List, Optional, Union
import enum import enum
@ -13,7 +13,6 @@ WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
from .typing import ( from .typing import (
TypeBase, TypeBase,
TypeNone, TypeNone,
TypeBool,
TypeUInt8, TypeUInt32, TypeUInt64, TypeUInt8, TypeUInt32, TypeUInt64,
TypeInt32, TypeInt64, TypeInt32, TypeInt64,
TypeFloat32, TypeFloat64, TypeFloat32, TypeFloat64,
@ -29,13 +28,11 @@ class Expression:
""" """
An expression within a statement An expression within a statement
""" """
__slots__ = ('type', 'type_var', ) __slots__ = ('type_var', )
type: TypeBase
type_var: Optional[TypeVar] type_var: Optional[TypeVar]
def __init__(self, type_: TypeBase) -> None: def __init__(self) -> None:
self.type = type_
self.type_var = None self.type_var = None
class Constant(Expression): class Constant(Expression):
@ -53,6 +50,7 @@ class ConstantPrimitive(Constant):
value: Union[int, float] value: Union[int, float]
def __init__(self, value: Union[int, float]) -> None: def __init__(self, value: Union[int, float]) -> None:
super().__init__()
self.value = value self.value = value
class ConstantTuple(Constant): class ConstantTuple(Constant):
@ -63,8 +61,8 @@ class ConstantTuple(Constant):
value: List[ConstantPrimitive] value: List[ConstantPrimitive]
def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
super().__init__(type_) super().__init__()
self.value = value self.value = value
class ConstantStaticArray(Constant): class ConstantStaticArray(Constant):
@ -75,8 +73,8 @@ class ConstantStaticArray(Constant):
value: List[ConstantPrimitive] value: List[ConstantPrimitive]
def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
super().__init__(type_) super().__init__()
self.value = value self.value = value
class VariableReference(Expression): class VariableReference(Expression):
@ -87,8 +85,8 @@ class VariableReference(Expression):
variable: 'FunctionParam' # also possibly local variable: 'FunctionParam' # also possibly local
def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None: def __init__(self, variable: 'FunctionParam') -> None:
super().__init__(type_) super().__init__()
self.variable = variable self.variable = variable
class UnaryOp(Expression): class UnaryOp(Expression):
@ -100,8 +98,8 @@ class UnaryOp(Expression):
operator: str operator: str
right: Expression right: Expression
def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: def __init__(self, operator: str, right: Expression) -> None:
super().__init__(type_) super().__init__()
self.operator = operator self.operator = operator
self.right = right self.right = right
@ -116,8 +114,8 @@ class BinaryOp(Expression):
left: Expression left: Expression
right: Expression right: Expression
def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None: def __init__(self, operator: str, left: Expression, right: Expression) -> None:
super().__init__(type_) super().__init__()
self.operator = operator self.operator = operator
self.left = left self.left = left
@ -133,7 +131,7 @@ class FunctionCall(Expression):
arguments: List[Expression] arguments: List[Expression]
def __init__(self, function: 'Function') -> None: def __init__(self, function: 'Function') -> None:
super().__init__(function.returns) super().__init__()
self.function = function self.function = function
self.arguments = [] self.arguments = []
@ -147,8 +145,8 @@ class AccessBytesIndex(Expression):
varref: VariableReference varref: VariableReference
index: Expression index: Expression
def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None: def __init__(self, varref: VariableReference, index: Expression) -> None:
super().__init__(type_) super().__init__()
self.varref = varref self.varref = varref
self.index = index self.index = index
@ -163,7 +161,7 @@ class AccessStructMember(Expression):
member: TypeStructMember member: TypeStructMember
def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: def __init__(self, varref: VariableReference, member: TypeStructMember) -> None:
super().__init__(member.type) super().__init__()
self.varref = varref self.varref = varref
self.member = member self.member = member
@ -178,7 +176,7 @@ class AccessTupleMember(Expression):
member: TypeTupleMember member: TypeTupleMember
def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None:
super().__init__(member.type) super().__init__()
self.varref = varref self.varref = varref
self.member = member self.member = member
@ -194,7 +192,7 @@ class AccessStaticArrayMember(Expression):
member: Union[Expression, TypeStaticArrayMember] member: Union[Expression, TypeStaticArrayMember]
def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None:
super().__init__(static_array.member_type) super().__init__()
self.varref = varref self.varref = varref
self.static_array = static_array self.static_array = static_array
@ -218,13 +216,12 @@ class Fold(Expression):
def __init__( def __init__(
self, self,
type_: TypeBase,
dir_: Direction, dir_: Direction,
func: 'Function', func: 'Function',
base: Expression, base: Expression,
iter_: Expression, iter_: Expression,
) -> None: ) -> None:
super().__init__(type_) super().__init__()
self.dir = dir_ self.dir = dir_
self.func = func self.func = func
@ -239,8 +236,8 @@ class ModuleConstantReference(Expression):
definition: 'ModuleConstantDef' definition: 'ModuleConstantDef'
def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None: def __init__(self, definition: 'ModuleConstantDef') -> None:
super().__init__(type_) super().__init__()
self.definition = definition self.definition = definition
class Statement: class Statement:
@ -280,6 +277,9 @@ class StatementIf(Statement):
self.else_statements = [] self.else_statements = []
class FunctionParam: class FunctionParam:
"""
A parameter for a Function
"""
__slots__ = ('name', 'type', 'type_var', ) __slots__ = ('name', 'type', 'type_var', )
name: str name: str

View File

@ -7,13 +7,6 @@ import ast
from .typing import ( from .typing import (
TypeBase, TypeBase,
TypeUInt8,
TypeUInt32,
TypeUInt64,
TypeInt32,
TypeInt64,
TypeFloat32,
TypeFloat64,
TypeBytes, TypeBytes,
TypeStruct, TypeStruct,
TypeStructMember, TypeStructMember,
@ -34,7 +27,6 @@ from .ourlang import (
Expression, Expression,
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp, BinaryOp,
Constant,
ConstantPrimitive, ConstantTuple, ConstantStaticArray, ConstantPrimitive, ConstantTuple, ConstantStaticArray,
FunctionCall, FunctionCall,
@ -242,7 +234,7 @@ class OurVisitor:
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type, exp_type,
ConstantTuple(exp_type, tuple_data), ConstantTuple(tuple_data),
data_block, data_block,
) )
@ -270,7 +262,7 @@ class OurVisitor:
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type, exp_type,
ConstantStaticArray(exp_type, static_array_data), ConstantStaticArray(static_array_data),
data_block, data_block,
) )
@ -359,7 +351,6 @@ class OurVisitor:
# e.g. you can do `"hello" * 3` with the code below (yet) # e.g. you can do `"hello" * 3` with the code below (yet)
return BinaryOp( return BinaryOp(
exp_type,
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right),
@ -374,7 +365,6 @@ class OurVisitor:
raise NotImplementedError(f'Operator {node.op}') raise NotImplementedError(f'Operator {node.op}')
return UnaryOp( return UnaryOp(
exp_type,
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand),
) )
@ -396,7 +386,6 @@ class OurVisitor:
# e.g. you can do `"hello" * 3` with the code below (yet) # e.g. you can do `"hello" * 3` with the code below (yet)
return BinaryOp( return BinaryOp(
exp_type,
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]),
@ -412,12 +401,12 @@ class OurVisitor:
if isinstance(node, ast.Attribute): if isinstance(node, ast.Attribute):
return self.visit_Module_FunctionDef_Attribute( return self.visit_Module_FunctionDef_Attribute(
module, function, our_locals, exp_type, node, module, function, our_locals, node,
) )
if isinstance(node, ast.Subscript): if isinstance(node, ast.Subscript):
return self.visit_Module_FunctionDef_Subscript( return self.visit_Module_FunctionDef_Subscript(
module, function, our_locals, exp_type, node, module, function, our_locals, node,
) )
if isinstance(node, ast.Name): if isinstance(node, ast.Name):
@ -426,21 +415,17 @@ class OurVisitor:
if node.id in our_locals: if node.id in our_locals:
param = our_locals[node.id] param = our_locals[node.id]
if exp_type != param.type: return VariableReference(param)
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}')
return VariableReference(param.type, param)
if node.id in module.constant_defs: if node.id in module.constant_defs:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
if exp_type != cdef.type: return ModuleConstantReference(cdef)
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}')
return ModuleConstantReference(exp_type, cdef)
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
raise NotImplementedError('TODO: Broken after new type system')
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
@ -478,40 +463,28 @@ class OurVisitor:
func = module.functions[struct_constructor.name] func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )):
_raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}')
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp( return UnaryOp(
exp_type,
'sqrt', 'sqrt',
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]),
) )
elif node.func.id == 'u32': elif node.func.id == 'u32':
if not isinstance(exp_type, TypeUInt32):
_raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
# FIXME: This is a stub, proper casting is todo # FIXME: This is a stub, proper casting is todo
return UnaryOp( return UnaryOp(
exp_type,
'cast', 'cast',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]),
) )
elif node.func.id == 'len': elif node.func.id == 'len':
if not isinstance(exp_type, TypeInt32):
_raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp( return UnaryOp(
exp_type,
'len', 'len',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]),
) )
@ -536,6 +509,8 @@ class OurVisitor:
if 2 != len(func.posonlyargs): if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
raise NotImplementedError('TODO: Broken after new type system')
if exp_type.__class__ != func.returns.__class__: if exp_type.__class__ != func.returns.__class__:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
@ -546,7 +521,6 @@ class OurVisitor:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold( return Fold(
exp_type,
Fold.Direction.LEFT, Fold.Direction.LEFT,
func, func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]),
@ -571,7 +545,7 @@ class OurVisitor:
) )
return result return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
del module del module
del function del function
@ -594,15 +568,12 @@ class OurVisitor:
if member is None: if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
if exp_type != member.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
return AccessStructMember( return AccessStructMember(
VariableReference(node_typ, param), VariableReference(param),
member, member,
) )
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name): if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name') _raise_static_error(node, 'Must reference a name')
@ -616,11 +587,11 @@ class OurVisitor:
if node.value.id in our_locals: if node.value.id in our_locals:
param = our_locals[node.value.id] param = our_locals[node.value.id]
node_typ = param.type node_typ = param.type
varref = VariableReference(param.type, param) varref = VariableReference(param)
elif node.value.id in module.constant_defs: elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id] constant_def = module.constant_defs[node.value.id]
node_typ = constant_def.type node_typ = constant_def.type
varref = ModuleConstantReference(node_typ, constant_def) varref = ModuleConstantReference(constant_def)
else: else:
_raise_static_error(node, f'Undefined variable {node.value.id}') _raise_static_error(node, f'Undefined variable {node.value.id}')
@ -629,15 +600,10 @@ class OurVisitor:
) )
if isinstance(node_typ, TypeBytes): if isinstance(node_typ, TypeBytes):
t_u8 = module.types['u8']
if exp_type != t_u8:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}')
if isinstance(varref, ModuleConstantReference): if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant') raise NotImplementedError(f'{node} from module constant')
return AccessBytesIndex( return AccessBytesIndex(
t_u8,
varref, varref,
slice_expr, slice_expr,
) )
@ -655,8 +621,6 @@ class OurVisitor:
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
tuple_member = node_typ.members[idx] tuple_member = node_typ.members[idx]
if exp_type != tuple_member.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}')
if isinstance(varref, ModuleConstantReference): if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant') raise NotImplementedError(f'{node} from module constant')
@ -667,9 +631,6 @@ class OurVisitor:
) )
if isinstance(node_typ, TypeStaticArray): if isinstance(node_typ, TypeStaticArray):
if exp_type != node_typ.member_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
if not isinstance(slice_expr, ConstantPrimitive): if not isinstance(slice_expr, ConstantPrimitive):
return AccessStaticArrayMember( return AccessStaticArrayMember(
varref, varref,

View File

@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
g.i32.const(0) g.i32.const(0)
g.return_() g.return_()
del alloc_size # TODO del alloc_size # TODO: Actual implement using a previously freed block
g.unreachable() g.unreachable()
return i32('return') # To satisfy mypy return i32('return') # To satisfy mypy

View File

@ -66,19 +66,27 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
return inp.variable.type_var return inp.variable.type_var
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
# TODO: Simplified version
if inp.operator not in ('sqrt', ): if inp.operator not in ('sqrt', ):
raise NotImplementedError(expression, inp, inp.operator) raise NotImplementedError(expression, inp, inp.operator)
right = expression(ctx, inp.right) right = expression(ctx, inp.right)
inp.type_var = right
return right return right
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
# TODO: Simplified version
if inp.operator not in ('+', '-', '*', '|', '&', '^'): if inp.operator not in ('+', '-', '*', '|', '&', '^'):
raise NotImplementedError(expression, inp, inp.operator) raise NotImplementedError(expression, inp, inp.operator)
left = expression(ctx, inp.left) left = expression(ctx, inp.left)
right = expression(ctx, inp.right) right = expression(ctx, inp.right)
ctx.unify(left, right) ctx.unify(left, right)
inp.type_var = left
return left return left
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
@ -94,6 +102,9 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.ModuleConstantReference): if isinstance(inp, ourlang.ModuleConstantReference):
assert inp.definition.type_var is not None assert inp.definition.type_var is not None
inp.type_var = inp.definition.type_var
return inp.definition.type_var return inp.definition.type_var
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
@ -133,30 +144,35 @@ def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> Type
result = ctx.new_var() result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8): if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location(location) result.add_location(location)
return result return result
if isinstance(inp, typing.TypeUInt32): if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location(location) result.add_location(location)
return result return result
if isinstance(inp, typing.TypeUInt64): if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location(location) result.add_location(location)
return result return result
if isinstance(inp, typing.TypeInt32): if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True)) result.add_constraint(TypeConstraintSigned(True))
result.add_location(location) result.add_location(location)
return result return result
if isinstance(inp, typing.TypeInt64): if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True)) result.add_constraint(TypeConstraintSigned(True))
result.add_location(location) result.add_location(location)

View File

@ -207,23 +207,45 @@ class TypeStruct(TypeBase):
## NEW STUFF BELOW ## NEW STUFF BELOW
# This error can also mean that the type somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class TypingNarrowProtoError(TypingError): class TypingNarrowProtoError(TypingError):
pass """
A proto error when trying to narrow two types
This gets turned into a TypingNarrowError by the unify method
"""
# FIXME: Use consistent naming for unify / narrow / entangle
class TypingNarrowError(TypingError): class TypingNarrowError(TypingError):
"""
An error when trying to unify two Type Variables
"""
def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None: def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
super().__init__( super().__init__(
f'Cannot narrow types {l} and {r}: {msg}' f'Cannot narrow types {l} and {r}: {msg}'
) )
class TypeConstraintBase: class TypeConstraintBase:
"""
Base class for classes implementing a contraint on a type
"""
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other) raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase): class TypeConstraintPrimitive(TypeConstraintBase):
"""
This contraint on a type defines its primitive shape
"""
__slots__ = ('primitive', ) __slots__ = ('primitive', )
class Primitive(enum.Enum): class Primitive(enum.Enum):
"""
The primitive ID
"""
INT = 0 INT = 0
FLOAT = 1 FLOAT = 1
@ -245,6 +267,10 @@ class TypeConstraintPrimitive(TypeConstraintBase):
return f'Primitive={self.primitive.name}' return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase): class TypeConstraintSigned(TypeConstraintBase):
"""
Contraint on whether a signed value can be used or not, or whether
a value can be used in a signed expression
"""
__slots__ = ('signed', ) __slots__ = ('signed', )
signed: Optional[bool] signed: Optional[bool]
@ -270,6 +296,9 @@ class TypeConstraintSigned(TypeConstraintBase):
return f'Signed={self.signed}' return f'Signed={self.signed}'
class TypeConstraintBitWidth(TypeConstraintBase): class TypeConstraintBitWidth(TypeConstraintBase):
"""
Contraint on how many bits an expression has or can possibly have
"""
__slots__ = ('minb', 'maxb', ) __slots__ = ('minb', 'maxb', )
minb: int minb: int
@ -301,6 +330,10 @@ class TypeConstraintBitWidth(TypeConstraintBase):
return f'BitWidth={self.minb}..{self.maxb}' return f'BitWidth={self.minb}..{self.maxb}'
class TypeVar: class TypeVar:
"""
A type variable
"""
# FIXME: Explain the type system
__slots__ = ('ctx', 'ctx_id', ) __slots__ = ('ctx', 'ctx_id', )
ctx: 'Context' ctx: 'Context'
@ -331,6 +364,9 @@ class TypeVar:
) )
class Context: class Context:
"""
The context for a collection of type variables
"""
def __init__(self) -> None: def __init__(self) -> None:
# Variables are unified (or entangled, if you will) # Variables are unified (or entangled, if you will)
# that means that each TypeVar within a context has an ID, # that means that each TypeVar within a context has an ID,
@ -399,6 +435,10 @@ class Context:
del self.var_locations[r_ctx_id] del self.var_locations[r_ctx_id]
def simplify(inp: TypeVar) -> Optional[str]: def simplify(inp: TypeVar) -> Optional[str]:
"""
Simplifies a TypeVar into a string that wasm can work with
and users can recognize
"""
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned)

View File

@ -1,7 +1,7 @@
""" """
Helper functions to quickly generate WASM code Helper functions to quickly generate WASM code
""" """
from typing import Any, Dict, List, Optional, Type from typing import List, Optional
import functools import functools

View File

@ -1,5 +1,5 @@
[MASTER] [MASTER]
disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223
max-line-length=180 max-line-length=180