Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
10 changed files with 184 additions and 182 deletions
Showing only changes of commit 564f00a419 - Show all commits

View File

@ -3,7 +3,7 @@ This module generates source code based on the parsed AST
It's intented to be a "any color, as long as it's black" kind of renderer
"""
from typing import Generator
from typing import Generator, Optional
from . import ourlang
from . import typing
@ -16,55 +16,17 @@ def phasm_render(inp: ourlang.Module) -> str:
Statements = Generator[str, None, None]
def type_(inp: typing.TypeBase) -> str:
def type_var(inp: Optional[typing.TypeVar]) -> str:
"""
Render: Type (name)
Render: type's name
"""
if isinstance(inp, typing.TypeNone):
return 'None'
assert inp is not None, typing.ASSERTION_ERROR
if isinstance(inp, typing.TypeBool):
return 'bool'
mtyp = typing.simplify(inp)
if mtyp is None:
raise NotImplementedError(f'Rendering type {inp}')
if isinstance(inp, typing.TypeUInt8):
return 'u8'
if isinstance(inp, typing.TypeUInt32):
return 'u32'
if isinstance(inp, typing.TypeUInt64):
return 'u64'
if isinstance(inp, typing.TypeInt32):
return 'i32'
if isinstance(inp, typing.TypeInt64):
return 'i64'
if isinstance(inp, typing.TypeFloat32):
return 'f32'
if isinstance(inp, typing.TypeFloat64):
return 'f64'
if isinstance(inp, typing.TypeBytes):
return 'bytes'
if isinstance(inp, typing.TypeTuple):
mems = ', '.join(
type_(x.type)
for x in inp.members
)
return f'({mems}, )'
if isinstance(inp, typing.TypeStaticArray):
return f'{type_(inp.member_type)}[{len(inp.members)}]'
if isinstance(inp, typing.TypeStruct):
return inp.name
raise NotImplementedError(type_, inp)
return mtyp
def struct_definition(inp: typing.TypeStruct) -> str:
"""
@ -72,7 +34,8 @@ def struct_definition(inp: typing.TypeStruct) -> str:
"""
result = f'class {inp.name}:\n'
for mem in inp.members:
result += f' {mem.name}: {type_(mem.type)}\n'
raise NotImplementedError('Structs broken after new type system')
# result += f' {mem.name}: {type_(mem.type)}\n'
return result
@ -80,7 +43,7 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str:
"""
Render: Module Constant's definition
"""
return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n'
return f'{inp.name}: {type_var(inp.type_var)} = {expression(inp.constant)}\n'
def expression(inp: ourlang.Expression) -> str:
"""
@ -107,7 +70,11 @@ def expression(inp: ourlang.Expression) -> str:
return f'{inp.operator}({expression(inp.right)})'
if inp.operator == 'cast':
return f'{type_(inp.type)}({expression(inp.right)})'
mtyp = type_var(inp.type_var)
if mtyp is None:
raise NotImplementedError(f'Casting to type {inp.type_var}')
return f'{mtyp}({expression(inp.right)})'
return f'{inp.operator}{expression(inp.right)}'
@ -187,11 +154,11 @@ def function(inp: ourlang.Function) -> str:
result += '@imported\n'
args = ', '.join(
f'{p.name}: {type_(p.type)}'
f'{p.name}: {type_var(p.type_var)}'
for p in inp.posonlyargs
)
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'
result += f'def {inp.name}({args}) -> {type_var(inp.returns_type_var)}:\n'
if inp.imported:
result += ' pass\n'

View File

@ -1,6 +1,8 @@
"""
This module contains the code to convert parsed Ourlang into WebAssembly code
"""
from typing import List
import struct
from . import codestyle
@ -132,7 +134,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
Compile: Any expression
"""
if isinstance(inp, ourlang.ConstantPrimitive):
assert inp.type_var is not None
assert inp.type_var is not None, typing.ASSERTION_ERROR
stp = typing.simplify(inp.type_var)
if stp is None:
@ -174,73 +176,80 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression(wgn, inp.left)
expression(wgn, inp.right)
if isinstance(inp.type, typing.TypeUInt8):
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp == 'u8':
if operator := U8_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if isinstance(inp.type, typing.TypeUInt32):
if mtyp == 'u32':
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if operator := U32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if isinstance(inp.type, typing.TypeUInt64):
if mtyp == 'u64':
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}')
return
if operator := U64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}')
return
if isinstance(inp.type, typing.TypeInt32):
if mtyp == 'i32':
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if operator := I32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if isinstance(inp.type, typing.TypeInt64):
if mtyp == 'i64':
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}')
return
if operator := I64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}')
return
if isinstance(inp.type, typing.TypeFloat32):
if mtyp == 'f32':
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}')
return
if isinstance(inp.type, typing.TypeFloat64):
if mtyp == 'f64':
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f64.{operator}')
return
raise NotImplementedError(expression, inp.type, inp.operator)
raise NotImplementedError(expression, inp.type_var, inp.operator)
if isinstance(inp, ourlang.UnaryOp):
expression(wgn, inp.right)
if isinstance(inp.type, typing.TypeFloat32):
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp == 'f32':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS:
wgn.add_statement(f'f32.{inp.operator}')
return
if isinstance(inp.type, typing.TypeFloat64):
if mtyp == 'f64':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS:
wgn.add_statement(f'f64.{inp.operator}')
return
if isinstance(inp.type, typing.TypeInt32):
if inp.operator == 'len':
if isinstance(inp.right.type, typing.TypeBytes):
wgn.i32.load()
return
# TODO: Broken after new type system
# if isinstance(inp.type, typing.TypeInt32):
# if inp.operator == 'len':
# if isinstance(inp.right.type, typing.TypeBytes):
# wgn.i32.load()
# return
if inp.operator == 'cast':
if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8):
# Nothing to do, you can use an u8 value as a u32 no problem
return
# if inp.operator == 'cast':
# if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8):
# # Nothing to do, you can use an u8 value as a u32 no problem
# return
raise NotImplementedError(expression, inp.type, inp.operator)
raise NotImplementedError(expression, inp.type_var, inp.operator)
if isinstance(inp, ourlang.FunctionCall):
for arg in inp.arguments:
@ -249,14 +258,15 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.add_statement('call', '${}'.format(inp.function.name))
return
if isinstance(inp, ourlang.AccessBytesIndex):
if not isinstance(inp.type, typing.TypeUInt8):
raise NotImplementedError(inp, inp.type)
expression(wgn, inp.varref)
expression(wgn, inp.index)
wgn.call(stdlib_types.__subscript_bytes__)
return
# TODO: Broken after new type system
# if isinstance(inp, ourlang.AccessBytesIndex):
# if not isinstance(inp.type, typing.TypeUInt8):
# raise NotImplementedError(inp, inp.type)
#
# expression(wgn, inp.varref)
# expression(wgn, inp.index)
# wgn.call(stdlib_types.__subscript_bytes__)
# return
if isinstance(inp, ourlang.AccessStructMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
@ -305,27 +315,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
return
if isinstance(inp, ourlang.ModuleConstantReference):
if isinstance(inp.type, typing.TypeTuple):
assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
assert inp.definition.data_block is not None, 'Combined values are memory stored'
assert inp.definition.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.definition.data_block.address)
return
if isinstance(inp.type, typing.TypeStaticArray):
assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
assert inp.definition.data_block is not None, 'Combined values are memory stored'
assert inp.definition.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.definition.data_block.address)
return
# FIXME: Tuple / Static Array broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
#
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
assert inp.definition.data_block is None, 'Primitives are not memory stored'
mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__)
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type)
raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.definition.constant)
return
@ -336,13 +348,15 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert inp.base.type_var is not None
assert inp.base.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.base.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.base)
raise NotImplementedError('TODO: Broken after new type system')
if inp.iter.type.__class__.__name__ != 'TypeBytes':
raise NotImplementedError(expression, inp, inp.iter.type)
@ -563,7 +577,9 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
for block in inp.blocks:
block.address = unalloc_ptr + 4 # 4 bytes for allocator header
data_list = []
data_list: List[bytes] = []
raise NotImplementedError('Broken after new type system')
for constant in block.data:
if isinstance(constant, ourlang.ConstantUInt8):

View File

@ -8,4 +8,6 @@ class StaticError(Exception):
"""
class TypingError(Exception):
pass
"""
An error found during the typing phase
"""

View File

@ -1,7 +1,7 @@
"""
Contains the syntax tree for ourlang
"""
from typing import Dict, List, Tuple, Optional, Union
from typing import Dict, List, Optional, Union
import enum
@ -13,7 +13,6 @@ WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
from .typing import (
TypeBase,
TypeNone,
TypeBool,
TypeUInt8, TypeUInt32, TypeUInt64,
TypeInt32, TypeInt64,
TypeFloat32, TypeFloat64,
@ -29,13 +28,11 @@ class Expression:
"""
An expression within a statement
"""
__slots__ = ('type', 'type_var', )
__slots__ = ('type_var', )
type: TypeBase
type_var: Optional[TypeVar]
def __init__(self, type_: TypeBase) -> None:
self.type = type_
def __init__(self) -> None:
self.type_var = None
class Constant(Expression):
@ -53,6 +50,7 @@ class ConstantPrimitive(Constant):
value: Union[int, float]
def __init__(self, value: Union[int, float]) -> None:
super().__init__()
self.value = value
class ConstantTuple(Constant):
@ -63,8 +61,8 @@ class ConstantTuple(Constant):
value: List[ConstantPrimitive]
def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
super().__init__(type_)
def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
super().__init__()
self.value = value
class ConstantStaticArray(Constant):
@ -75,8 +73,8 @@ class ConstantStaticArray(Constant):
value: List[ConstantPrimitive]
def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
super().__init__(type_)
def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
super().__init__()
self.value = value
class VariableReference(Expression):
@ -87,8 +85,8 @@ class VariableReference(Expression):
variable: 'FunctionParam' # also possibly local
def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None:
super().__init__(type_)
def __init__(self, variable: 'FunctionParam') -> None:
super().__init__()
self.variable = variable
class UnaryOp(Expression):
@ -100,8 +98,8 @@ class UnaryOp(Expression):
operator: str
right: Expression
def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None:
super().__init__(type_)
def __init__(self, operator: str, right: Expression) -> None:
super().__init__()
self.operator = operator
self.right = right
@ -116,8 +114,8 @@ class BinaryOp(Expression):
left: Expression
right: Expression
def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None:
super().__init__(type_)
def __init__(self, operator: str, left: Expression, right: Expression) -> None:
super().__init__()
self.operator = operator
self.left = left
@ -133,7 +131,7 @@ class FunctionCall(Expression):
arguments: List[Expression]
def __init__(self, function: 'Function') -> None:
super().__init__(function.returns)
super().__init__()
self.function = function
self.arguments = []
@ -147,8 +145,8 @@ class AccessBytesIndex(Expression):
varref: VariableReference
index: Expression
def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None:
super().__init__(type_)
def __init__(self, varref: VariableReference, index: Expression) -> None:
super().__init__()
self.varref = varref
self.index = index
@ -163,7 +161,7 @@ class AccessStructMember(Expression):
member: TypeStructMember
def __init__(self, varref: VariableReference, member: TypeStructMember) -> None:
super().__init__(member.type)
super().__init__()
self.varref = varref
self.member = member
@ -178,7 +176,7 @@ class AccessTupleMember(Expression):
member: TypeTupleMember
def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None:
super().__init__(member.type)
super().__init__()
self.varref = varref
self.member = member
@ -194,7 +192,7 @@ class AccessStaticArrayMember(Expression):
member: Union[Expression, TypeStaticArrayMember]
def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None:
super().__init__(static_array.member_type)
super().__init__()
self.varref = varref
self.static_array = static_array
@ -218,13 +216,12 @@ class Fold(Expression):
def __init__(
self,
type_: TypeBase,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__(type_)
super().__init__()
self.dir = dir_
self.func = func
@ -239,8 +236,8 @@ class ModuleConstantReference(Expression):
definition: 'ModuleConstantDef'
def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None:
super().__init__(type_)
def __init__(self, definition: 'ModuleConstantDef') -> None:
super().__init__()
self.definition = definition
class Statement:
@ -280,6 +277,9 @@ class StatementIf(Statement):
self.else_statements = []
class FunctionParam:
"""
A parameter for a Function
"""
__slots__ = ('name', 'type', 'type_var', )
name: str

View File

@ -7,13 +7,6 @@ import ast
from .typing import (
TypeBase,
TypeUInt8,
TypeUInt32,
TypeUInt64,
TypeInt32,
TypeInt64,
TypeFloat32,
TypeFloat64,
TypeBytes,
TypeStruct,
TypeStructMember,
@ -34,7 +27,6 @@ from .ourlang import (
Expression,
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp,
Constant,
ConstantPrimitive, ConstantTuple, ConstantStaticArray,
FunctionCall,
@ -242,7 +234,7 @@ class OurVisitor:
node.target.id,
node.lineno,
exp_type,
ConstantTuple(exp_type, tuple_data),
ConstantTuple(tuple_data),
data_block,
)
@ -270,7 +262,7 @@ class OurVisitor:
node.target.id,
node.lineno,
exp_type,
ConstantStaticArray(exp_type, static_array_data),
ConstantStaticArray(static_array_data),
data_block,
)
@ -359,7 +351,6 @@ class OurVisitor:
# e.g. you can do `"hello" * 3` with the code below (yet)
return BinaryOp(
exp_type,
operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right),
@ -374,7 +365,6 @@ class OurVisitor:
raise NotImplementedError(f'Operator {node.op}')
return UnaryOp(
exp_type,
operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand),
)
@ -396,7 +386,6 @@ class OurVisitor:
# e.g. you can do `"hello" * 3` with the code below (yet)
return BinaryOp(
exp_type,
operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]),
@ -412,12 +401,12 @@ class OurVisitor:
if isinstance(node, ast.Attribute):
return self.visit_Module_FunctionDef_Attribute(
module, function, our_locals, exp_type, node,
module, function, our_locals, node,
)
if isinstance(node, ast.Subscript):
return self.visit_Module_FunctionDef_Subscript(
module, function, our_locals, exp_type, node,
module, function, our_locals, node,
)
if isinstance(node, ast.Name):
@ -426,21 +415,17 @@ class OurVisitor:
if node.id in our_locals:
param = our_locals[node.id]
if exp_type != param.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}')
return VariableReference(param.type, param)
return VariableReference(param)
if node.id in module.constant_defs:
cdef = module.constant_defs[node.id]
if exp_type != cdef.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}')
return ModuleConstantReference(exp_type, cdef)
return ModuleConstantReference(cdef)
_raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple):
raise NotImplementedError('TODO: Broken after new type system')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
@ -478,40 +463,28 @@ class OurVisitor:
func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )):
_raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}')
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp(
exp_type,
'sqrt',
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]),
)
elif node.func.id == 'u32':
if not isinstance(exp_type, TypeUInt32):
_raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
# FIXME: This is a stub, proper casting is todo
return UnaryOp(
exp_type,
'cast',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]),
)
elif node.func.id == 'len':
if not isinstance(exp_type, TypeInt32):
_raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp(
exp_type,
'len',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]),
)
@ -536,6 +509,8 @@ class OurVisitor:
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
raise NotImplementedError('TODO: Broken after new type system')
if exp_type.__class__ != func.returns.__class__:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
@ -546,7 +521,6 @@ class OurVisitor:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold(
exp_type,
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]),
@ -571,7 +545,7 @@ class OurVisitor:
)
return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression:
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
del module
del function
@ -594,15 +568,12 @@ class OurVisitor:
if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
if exp_type != member.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
return AccessStructMember(
VariableReference(node_typ, param),
VariableReference(param),
member,
)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression:
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
@ -616,11 +587,11 @@ class OurVisitor:
if node.value.id in our_locals:
param = our_locals[node.value.id]
node_typ = param.type
varref = VariableReference(param.type, param)
varref = VariableReference(param)
elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id]
node_typ = constant_def.type
varref = ModuleConstantReference(node_typ, constant_def)
varref = ModuleConstantReference(constant_def)
else:
_raise_static_error(node, f'Undefined variable {node.value.id}')
@ -629,15 +600,10 @@ class OurVisitor:
)
if isinstance(node_typ, TypeBytes):
t_u8 = module.types['u8']
if exp_type != t_u8:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}')
if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant')
return AccessBytesIndex(
t_u8,
varref,
slice_expr,
)
@ -655,8 +621,6 @@ class OurVisitor:
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
tuple_member = node_typ.members[idx]
if exp_type != tuple_member.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}')
if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant')
@ -667,9 +631,6 @@ class OurVisitor:
)
if isinstance(node_typ, TypeStaticArray):
if exp_type != node_typ.member_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
if not isinstance(slice_expr, ConstantPrimitive):
return AccessStaticArrayMember(
varref,

View File

@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
g.i32.const(0)
g.return_()
del alloc_size # TODO
del alloc_size # TODO: Actual implement using a previously freed block
g.unreachable()
return i32('return') # To satisfy mypy

View File

@ -66,19 +66,27 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
return inp.variable.type_var
if isinstance(inp, ourlang.UnaryOp):
# TODO: Simplified version
if inp.operator not in ('sqrt', ):
raise NotImplementedError(expression, inp, inp.operator)
right = expression(ctx, inp.right)
inp.type_var = right
return right
if isinstance(inp, ourlang.BinaryOp):
# TODO: Simplified version
if inp.operator not in ('+', '-', '*', '|', '&', '^'):
raise NotImplementedError(expression, inp, inp.operator)
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
inp.type_var = left
return left
if isinstance(inp, ourlang.FunctionCall):
@ -94,6 +102,9 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.ModuleConstantReference):
assert inp.definition.type_var is not None
inp.type_var = inp.definition.type_var
return inp.definition.type_var
raise NotImplementedError(expression, inp)
@ -133,30 +144,35 @@ def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> Type
result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)

View File

@ -207,23 +207,45 @@ class TypeStruct(TypeBase):
## NEW STUFF BELOW
# This error can also mean that the type somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class TypingNarrowProtoError(TypingError):
pass
"""
A proto error when trying to narrow two types
This gets turned into a TypingNarrowError by the unify method
"""
# FIXME: Use consistent naming for unify / narrow / entangle
class TypingNarrowError(TypingError):
"""
An error when trying to unify two Type Variables
"""
def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
super().__init__(
f'Cannot narrow types {l} and {r}: {msg}'
)
class TypeConstraintBase:
"""
Base class for classes implementing a contraint on a type
"""
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
"""
This contraint on a type defines its primitive shape
"""
__slots__ = ('primitive', )
class Primitive(enum.Enum):
"""
The primitive ID
"""
INT = 0
FLOAT = 1
@ -245,6 +267,10 @@ class TypeConstraintPrimitive(TypeConstraintBase):
return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase):
"""
Contraint on whether a signed value can be used or not, or whether
a value can be used in a signed expression
"""
__slots__ = ('signed', )
signed: Optional[bool]
@ -270,6 +296,9 @@ class TypeConstraintSigned(TypeConstraintBase):
return f'Signed={self.signed}'
class TypeConstraintBitWidth(TypeConstraintBase):
"""
Contraint on how many bits an expression has or can possibly have
"""
__slots__ = ('minb', 'maxb', )
minb: int
@ -301,6 +330,10 @@ class TypeConstraintBitWidth(TypeConstraintBase):
return f'BitWidth={self.minb}..{self.maxb}'
class TypeVar:
"""
A type variable
"""
# FIXME: Explain the type system
__slots__ = ('ctx', 'ctx_id', )
ctx: 'Context'
@ -331,6 +364,9 @@ class TypeVar:
)
class Context:
"""
The context for a collection of type variables
"""
def __init__(self) -> None:
# Variables are unified (or entangled, if you will)
# that means that each TypeVar within a context has an ID,
@ -399,6 +435,10 @@ class Context:
del self.var_locations[r_ctx_id]
def simplify(inp: TypeVar) -> Optional[str]:
"""
Simplifies a TypeVar into a string that wasm can work with
and users can recognize
"""
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned)

View File

@ -1,7 +1,7 @@
"""
Helper functions to quickly generate WASM code
"""
from typing import Any, Dict, List, Optional, Type
from typing import List, Optional
import functools

View File

@ -1,5 +1,5 @@
[MASTER]
disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223
disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223
max-line-length=180