Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
6 changed files with 504 additions and 601 deletions
Showing only changes of commit 07c0688d1b - Show all commits

View File

@ -33,7 +33,7 @@ def struct_definition(inp: typing.TypeStruct) -> str:
Render: TypeStruct's definition
"""
result = f'class {inp.name}:\n'
for mem in inp.members:
for mem in inp.members: # TODO: Broken after new type system
raise NotImplementedError('Structs broken after new type system')
# result += f' {mem.name}: {type_(mem.type)}\n'
@ -87,11 +87,12 @@ def expression(inp: ourlang.Expression) -> str:
for arg in inp.arguments
)
if isinstance(inp.function, ourlang.StructConstructor):
return f'{inp.function.struct.name}({args})'
if isinstance(inp.function, ourlang.TupleConstructor):
return f'({args}, )'
# TODO: Broken after new type system
# if isinstance(inp.function, ourlang.StructConstructor):
# return f'{inp.function.struct.name}({args})'
#
# if isinstance(inp.function, ourlang.TupleConstructor):
# return f'({args}, )'
return f'{inp.function.name}({args})'

View File

@ -1,7 +1,7 @@
"""
This module contains the code to convert parsed Ourlang into WebAssembly code
"""
from typing import List
from typing import List, Optional
import struct
@ -14,19 +14,6 @@ from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types
from .wasmgenerator import Generator as WasmGenerator
LOAD_STORE_TYPE_MAP = {
typing.TypeUInt8: 'i32',
typing.TypeUInt32: 'i32',
typing.TypeUInt64: 'i64',
typing.TypeInt32: 'i32',
typing.TypeInt64: 'i64',
typing.TypeFloat32: 'f32',
typing.TypeFloat64: 'f64',
}
"""
When generating code, we sometimes need to load or store simple values
"""
def phasm_compile(inp: ourlang.Module) -> wasm.Module:
"""
Public method for compiling a parsed Phasm module into
@ -34,42 +21,44 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module:
"""
return module(inp)
def type_(inp: typing.TypeBase) -> wasm.WasmType:
def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
"""
Compile: type
"""
if isinstance(inp, typing.TypeNone):
return wasm.WasmTypeNone()
assert inp is not None, typing.ASSERTION_ERROR
if isinstance(inp, typing.TypeUInt8):
mtyp = typing.simplify(inp)
if mtyp == 'u8':
# WebAssembly has only support for 32 and 64 bits
# So we need to store more memory per byte
return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeUInt32):
if mtyp == 'u32':
return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeUInt64):
if mtyp == 'u64':
return wasm.WasmTypeInt64()
if isinstance(inp, typing.TypeInt32):
if mtyp == 'i32':
return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeInt64):
if mtyp == 'i64':
return wasm.WasmTypeInt64()
if isinstance(inp, typing.TypeFloat32):
if mtyp == 'f32':
return wasm.WasmTypeFloat32()
if isinstance(inp, typing.TypeFloat64):
if mtyp == 'f64':
return wasm.WasmTypeFloat64()
if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
# Structs and tuples are passed as pointer
# And pointers are i32
return wasm.WasmTypeInt32()
# TODO: Broken after new type system
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
# # Structs and tuples are passed as pointer
# # And pointers are i32
# return wasm.WasmTypeInt32()
raise NotImplementedError(type_, inp)
raise NotImplementedError(inp, mtyp)
# Operators that work for i32, i64, f32, f64
OPERATOR_MAP = {
@ -268,47 +257,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# wgn.call(stdlib_types.__subscript_bytes__)
# return
if isinstance(inp, ourlang.AccessStructMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.member)
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return
if isinstance(inp, ourlang.AccessTupleMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.member)
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return
if isinstance(inp, ourlang.AccessStaticArrayMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of static arrays
raise NotImplementedError(expression, inp, inp.member)
if isinstance(inp.member, typing.TypeStaticArrayMember):
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return
expression(wgn, inp.varref)
expression(wgn, inp.member)
wgn.i32.const(inp.static_array.member_type.alloc_size())
wgn.i32.mul()
wgn.i32.add()
wgn.add_statement(f'{mtyp}.load')
return
# if isinstance(inp, ourlang.AccessStructMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of struct or tuples
# raise NotImplementedError(expression, inp, inp.member)
#
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# if isinstance(inp, ourlang.AccessTupleMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of struct or tuples
# raise NotImplementedError(expression, inp, inp.member)
#
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# if isinstance(inp, ourlang.AccessStaticArrayMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of static arrays
# raise NotImplementedError(expression, inp, inp.member)
#
# if isinstance(inp.member, typing.TypeStaticArrayMember):
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# expression(wgn, inp.varref)
# expression(wgn, inp.member)
# wgn.i32.const(inp.static_array.member_type.alloc_size())
# wgn.i32.mul()
# wgn.i32.add()
# wgn.add_statement(f'{mtyp}.load')
# return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
@ -472,7 +461,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
"""
Compile: function argument
"""
return (inp.name, type_(inp.type), )
return (inp.name, type_var(inp.type_var), )
def import_(inp: ourlang.Function) -> wasm.Import:
"""
@ -488,7 +477,7 @@ def import_(inp: ourlang.Function) -> wasm.Import:
function_argument(x)
for x in inp.posonlyargs
],
type_(inp.returns)
type_var(inp.returns_type_var)
)
def function(inp: ourlang.Function) -> wasm.Function:
@ -499,10 +488,10 @@ def function(inp: ourlang.Function) -> wasm.Function:
wgn = WasmGenerator()
if isinstance(inp, ourlang.TupleConstructor):
_generate_tuple_constructor(wgn, inp)
elif isinstance(inp, ourlang.StructConstructor):
_generate_struct_constructor(wgn, inp)
if False: # TODO: isinstance(inp, ourlang.TupleConstructor):
pass # _generate_tuple_constructor(wgn, inp)
elif False: # TODO: isinstance(inp, ourlang.StructConstructor):
pass # _generate_struct_constructor(wgn, inp)
else:
for stat in inp.statements:
statement(wgn, stat)
@ -518,7 +507,7 @@ def function(inp: ourlang.Function) -> wasm.Function:
(k, v.wasm_type(), )
for k, v in wgn.locals.items()
],
type_(inp.returns),
type_var(inp.returns_type_var),
wgn.statements
)
@ -660,48 +649,49 @@ def module(inp: ourlang.Module) -> wasm.Module:
return result
def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None:
tmp_var = wgn.temp_var_i32('tuple_adr')
# Allocated the required amounts of bytes in memory
wgn.i32.const(inp.tuple.alloc_size())
wgn.call(stdlib_alloc.__alloc__)
wgn.local.set(tmp_var)
# Store each member individually
for member in inp.tuple.members:
mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, member)
wgn.local.get(tmp_var)
wgn.add_statement('local.get', f'$arg{member.idx}')
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
# Return the allocated address
wgn.local.get(tmp_var)
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:
tmp_var = wgn.temp_var_i32('struct_adr')
# Allocated the required amounts of bytes in memory
wgn.i32.const(inp.struct.alloc_size())
wgn.call(stdlib_alloc.__alloc__)
wgn.local.set(tmp_var)
# Store each member individually
for member in inp.struct.members:
mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, member)
wgn.local.get(tmp_var)
wgn.add_statement('local.get', f'${member.name}')
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
# Return the allocated address
wgn.local.get(tmp_var)
# TODO: Broken after new type system
# def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None:
# tmp_var = wgn.temp_var_i32('tuple_adr')
#
# # Allocated the required amounts of bytes in memory
# wgn.i32.const(inp.tuple.alloc_size())
# wgn.call(stdlib_alloc.__alloc__)
# wgn.local.set(tmp_var)
#
# # Store each member individually
# for member in inp.tuple.members:
# mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of struct or tuples
# raise NotImplementedError(expression, inp, member)
#
# wgn.local.get(tmp_var)
# wgn.add_statement('local.get', f'$arg{member.idx}')
# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
#
# # Return the allocated address
# wgn.local.get(tmp_var)
#
# def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:
# tmp_var = wgn.temp_var_i32('struct_adr')
#
# # Allocated the required amounts of bytes in memory
# wgn.i32.const(inp.struct.alloc_size())
# wgn.call(stdlib_alloc.__alloc__)
# wgn.local.set(tmp_var)
#
# # Store each member individually
# for member in inp.struct.members:
# mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of struct or tuples
# raise NotImplementedError(expression, inp, member)
#
# wgn.local.get(tmp_var)
# wgn.add_statement('local.get', f'${member.name}')
# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
#
# # Return the allocated address
# wgn.local.get(tmp_var)

View File

@ -11,11 +11,6 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc',
WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
from .typing import (
TypeBase,
TypeNone,
TypeUInt8, TypeUInt32, TypeUInt64,
TypeInt32, TypeInt64,
TypeFloat32, TypeFloat64,
TypeBytes,
TypeTuple, TypeTupleMember,
TypeStaticArray, TypeStaticArrayMember,
@ -280,29 +275,29 @@ class FunctionParam:
"""
A parameter for a Function
"""
__slots__ = ('name', 'type', 'type_var', )
__slots__ = ('name', 'type_str', 'type_var', )
name: str
type: TypeBase
type_str: str
type_var: Optional[TypeVar]
def __init__(self, name: str, type_: TypeBase) -> None:
def __init__(self, name: str, type_str: str) -> None:
self.name = name
self.type = type_
self.type_str = type_str
self.type_var = None
class Function:
"""
A function processes input and produces output
"""
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', )
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns_str', 'returns_type_var', 'posonlyargs', )
name: str
lineno: int
exported: bool
imported: bool
statements: List[Statement]
returns: TypeBase
returns_str: str
returns_type_var: Optional[TypeVar]
posonlyargs: List[FunctionParam]
@ -312,68 +307,67 @@ class Function:
self.exported = False
self.imported = False
self.statements = []
self.returns = TypeNone()
self.returns_str = 'None'
self.returns_type_var = None
self.posonlyargs = []
class StructConstructor(Function):
"""
The constructor method for a struct
A function will generated to instantiate a struct. The arguments
will be the defaults
"""
__slots__ = ('struct', )
struct: TypeStruct
def __init__(self, struct: TypeStruct) -> None:
super().__init__(f'@{struct.name}@__init___@', -1)
self.returns = struct
for mem in struct.members:
self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
self.struct = struct
class TupleConstructor(Function):
"""
The constructor method for a tuple
"""
__slots__ = ('tuple', )
tuple: TypeTuple
def __init__(self, tuple_: TypeTuple) -> None:
name = tuple_.render_internal_name()
super().__init__(f'@{name}@__init___@', -1)
self.returns = tuple_
for mem in tuple_.members:
self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
self.tuple = tuple_
# TODO: Broken after new type system
# class StructConstructor(Function):
# """
# The constructor method for a struct
#
# A function will generated to instantiate a struct. The arguments
# will be the defaults
# """
# __slots__ = ('struct', )
#
# struct: TypeStruct
#
# def __init__(self, struct: TypeStruct) -> None:
# super().__init__(f'@{struct.name}@__init___@', -1)
#
# self.returns = struct
#
# for mem in struct.members:
# self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
#
# self.struct = struct
#
# class TupleConstructor(Function):
# """
# The constructor method for a tuple
# """
# __slots__ = ('tuple', )
#
# tuple: TypeTuple
#
# def __init__(self, tuple_: TypeTuple) -> None:
# name = tuple_.render_internal_name()
#
# super().__init__(f'@{name}@__init___@', -1)
#
# self.returns = tuple_
#
# for mem in tuple_.members:
# self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
#
# self.tuple = tuple_
class ModuleConstantDef:
"""
A constant definition within a module
"""
__slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', )
__slots__ = ('name', 'lineno', 'type_var', 'constant', 'data_block', )
name: str
lineno: int
type: TypeBase
type_var: Optional[TypeVar]
constant: Constant
data_block: Optional['ModuleDataBlock']
def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None:
def __init__(self, name: str, lineno: int, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None:
self.name = name
self.lineno = lineno
self.type = type_
self.type_var = None
self.constant = constant
self.data_block = data_block
@ -409,23 +403,11 @@ class Module:
__slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',)
data: ModuleData
types: Dict[str, TypeBase]
structs: Dict[str, TypeStruct]
constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function]
def __init__(self) -> None:
self.types = {
'None': TypeNone(),
'u8': TypeUInt8(),
'u32': TypeUInt32(),
'u64': TypeUInt64(),
'i32': TypeInt32(),
'i64': TypeInt64(),
'f32': TypeFloat32(),
'f64': TypeFloat64(),
'bytes': TypeBytes(),
}
self.data = ModuleData()
self.structs = {}
self.constant_defs = {}

View File

@ -6,8 +6,6 @@ from typing import Any, Dict, NoReturn, Union
import ast
from .typing import (
TypeBase,
TypeBytes,
TypeStruct,
TypeStructMember,
TypeTuple,
@ -16,7 +14,6 @@ from .typing import (
TypeStaticArrayMember,
)
from . import codestyle
from .exceptions import StaticError
from .ourlang import (
WEBASSEMBLY_BUILDIN_FLOAT_OPS,
@ -30,7 +27,7 @@ from .ourlang import (
ConstantPrimitive, ConstantTuple, ConstantStaticArray,
FunctionCall,
StructConstructor, TupleConstructor,
# StructConstructor, TupleConstructor,
UnaryOp, VariableReference,
Fold, ModuleConstantReference,
@ -86,15 +83,16 @@ class OurVisitor:
module.constant_defs[res.name] = res
if isinstance(res, TypeStruct):
if res.name in module.structs:
raise StaticError(
f'{res.name} already defined on line {module.structs[res.name].lineno}'
)
module.structs[res.name] = res
constructor = StructConstructor(res)
module.functions[constructor.name] = constructor
# TODO: Broken after type system
# if isinstance(res, TypeStruct):
# if res.name in module.structs:
# raise StaticError(
# f'{res.name} already defined on line {module.structs[res.name].lineno}'
# )
#
# module.structs[res.name] = res
# constructor = StructConstructor(res)
# module.functions[constructor.name] = constructor
if isinstance(res, Function):
if res.name in module.functions:
@ -158,7 +156,7 @@ class OurVisitor:
function.imported = True
if node.returns:
function.returns = self.visit_type(module, node.returns)
function.returns_str = self.visit_type(module, node.returns)
_not_implemented(not node.type_comment, 'FunctionDef.type_comment')
@ -186,6 +184,7 @@ class OurVisitor:
if stmt.simple != 1:
raise NotImplementedError('Class with non-simple arguments')
raise NotImplementedError('TODO: Broken after new type system')
member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset)
struct.members.append(member)
@ -199,74 +198,72 @@ class OurVisitor:
if not isinstance(node.target.ctx, ast.Store):
_raise_static_error(node, 'Must be load context')
exp_type = self.visit_type(module, node.annotation)
if isinstance(node.value, ast.Constant):
return ModuleConstantDef(
node.target.id,
node.lineno,
exp_type,
self.visit_Module_Constant(module, node.value),
None,
)
if isinstance(exp_type, TypeTuple):
if not isinstance(node.value, ast.Tuple):
_raise_static_error(node, 'Must be tuple')
raise NotImplementedError('TODO: Broken after new typing system')
if len(exp_type.members) != len(node.value.elts):
_raise_static_error(node, 'Invalid number of tuple values')
tuple_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node, mem in zip(node.value.elts, exp_type.members)
if isinstance(arg_node, ast.Constant)
]
if len(exp_type.members) != len(tuple_data):
_raise_static_error(node, 'Tuple arguments must be constants')
# Allocate the data
data_block = ModuleDataBlock(tuple_data)
module.data.blocks.append(data_block)
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
exp_type,
ConstantTuple(tuple_data),
data_block,
)
if isinstance(exp_type, TypeStaticArray):
if not isinstance(node.value, ast.Tuple):
_raise_static_error(node, 'Must be static array')
if len(exp_type.members) != len(node.value.elts):
_raise_static_error(node, 'Invalid number of static array values')
static_array_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node in node.value.elts
if isinstance(arg_node, ast.Constant)
]
if len(exp_type.members) != len(static_array_data):
_raise_static_error(node, 'Static array arguments must be constants')
# Allocate the data
data_block = ModuleDataBlock(static_array_data)
module.data.blocks.append(data_block)
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
exp_type,
ConstantStaticArray(static_array_data),
data_block,
)
raise NotImplementedError(f'{node} on Module AnnAssign')
# if isinstance(exp_type, TypeTuple):
# if not isinstance(node.value, ast.Tuple):
# _raise_static_error(node, 'Must be tuple')
#
# if len(exp_type.members) != len(node.value.elts):
# _raise_static_error(node, 'Invalid number of tuple values')
#
# tuple_data = [
# self.visit_Module_Constant(module, arg_node)
# for arg_node, mem in zip(node.value.elts, exp_type.members)
# if isinstance(arg_node, ast.Constant)
# ]
# if len(exp_type.members) != len(tuple_data):
# _raise_static_error(node, 'Tuple arguments must be constants')
#
# # Allocate the data
# data_block = ModuleDataBlock(tuple_data)
# module.data.blocks.append(data_block)
#
# # Then return the constant as a pointer
# return ModuleConstantDef(
# node.target.id,
# node.lineno,
# exp_type,
# ConstantTuple(tuple_data),
# data_block,
# )
#
# if isinstance(exp_type, TypeStaticArray):
# if not isinstance(node.value, ast.Tuple):
# _raise_static_error(node, 'Must be static array')
#
# if len(exp_type.members) != len(node.value.elts):
# _raise_static_error(node, 'Invalid number of static array values')
#
# static_array_data = [
# self.visit_Module_Constant(module, arg_node)
# for arg_node in node.value.elts
# if isinstance(arg_node, ast.Constant)
# ]
# if len(exp_type.members) != len(static_array_data):
# _raise_static_error(node, 'Static array arguments must be constants')
#
# # Allocate the data
# data_block = ModuleDataBlock(static_array_data)
# module.data.blocks.append(data_block)
#
# # Then return the constant as a pointer
# return ModuleConstantDef(
# node.target.id,
# node.lineno,
# ConstantStaticArray(static_array_data),
# data_block,
# )
#
# raise NotImplementedError(f'{node} on Module AnnAssign')
def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None:
if isinstance(node, ast.FunctionDef):
@ -301,12 +298,12 @@ class OurVisitor:
_raise_static_error(node, 'Return must have an argument')
return StatementReturn(
self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value)
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
)
if isinstance(node, ast.If):
result = StatementIf(
self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test)
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test)
)
for stmt in node.body:
@ -326,7 +323,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as stmt in FunctionDef')
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression:
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression:
if isinstance(node, ast.BinOp):
if isinstance(node.op, ast.Add):
operator = '+'
@ -352,8 +349,8 @@ class OurVisitor:
return BinaryOp(
operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
)
if isinstance(node, ast.UnaryOp):
@ -366,7 +363,7 @@ class OurVisitor:
return UnaryOp(
operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.operand),
)
if isinstance(node, ast.Compare):
@ -387,12 +384,12 @@ class OurVisitor:
return BinaryOp(
operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
)
if isinstance(node, ast.Call):
return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node)
return self.visit_Module_FunctionDef_Call(module, function, our_locals, node)
if isinstance(node, ast.Constant):
return self.visit_Module_Constant(
@ -426,29 +423,29 @@ class OurVisitor:
if isinstance(node, ast.Tuple):
raise NotImplementedError('TODO: Broken after new type system')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if isinstance(exp_type, TypeTuple):
if len(exp_type.members) != len(node.elts):
_raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given')
tuple_constructor = TupleConstructor(exp_type)
func = module.functions[tuple_constructor.name]
result = FunctionCall(func)
result.arguments = [
self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node)
for arg_node, mem in zip(node.elts, exp_type.members)
]
return result
_raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple')
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# if isinstance(exp_type, TypeTuple):
# if len(exp_type.members) != len(node.elts):
# _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given')
#
# tuple_constructor = TupleConstructor(exp_type)
#
# func = module.functions[tuple_constructor.name]
#
# result = FunctionCall(func)
# result.arguments = [
# self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node)
# for arg_node, mem in zip(node.elts, exp_type.members)
# ]
# return result
#
# _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple')
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -458,17 +455,18 @@ class OurVisitor:
_raise_static_error(node, 'Must be load context')
if node.func.id in module.structs:
struct = module.structs[node.func.id]
struct_constructor = StructConstructor(struct)
func = module.functions[struct_constructor.name]
raise NotImplementedError('TODO: Broken after new type system')
# struct = module.structs[node.func.id]
# struct_constructor = StructConstructor(struct)
#
# func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp(
'sqrt',
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
)
elif node.func.id == 'u32':
if 1 != len(node.args):
@ -478,7 +476,7 @@ class OurVisitor:
return UnaryOp(
'cast',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
)
elif node.func.id == 'len':
if 1 != len(node.args):
@ -486,7 +484,7 @@ class OurVisitor:
return UnaryOp(
'len',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
)
elif node.func.id == 'foldl':
# TODO: This should a much more generic function!
@ -511,20 +509,11 @@ class OurVisitor:
raise NotImplementedError('TODO: Broken after new type system')
if exp_type.__class__ != func.returns.__class__:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if func.returns.__class__ != func.posonlyargs[0].type.__class__:
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')
if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
else:
if node.func.id not in module.functions:
@ -532,20 +521,18 @@ class OurVisitor:
func = module.functions[node.func.id]
# if func.returns != exp_type:
# _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
result = FunctionCall(func)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr)
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr, param in zip(node.args, func.posonlyargs)
)
return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
raise NotImplementedError('Broken after new type system')
del module
del function
@ -574,87 +561,89 @@ class OurVisitor:
)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
raise NotImplementedError('TODO: Broken after new type system')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
varref: Union[ModuleConstantReference, VariableReference]
if node.value.id in our_locals:
param = our_locals[node.value.id]
node_typ = param.type
varref = VariableReference(param)
elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id]
node_typ = constant_def.type
varref = ModuleConstantReference(constant_def)
else:
_raise_static_error(node, f'Undefined variable {node.value.id}')
slice_expr = self.visit_Module_FunctionDef_expr(
module, function, our_locals, module.types['u32'], node.slice.value,
)
if isinstance(node_typ, TypeBytes):
if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant')
return AccessBytesIndex(
varref,
slice_expr,
)
if isinstance(node_typ, TypeTuple):
if not isinstance(slice_expr, ConstantPrimitive):
_raise_static_error(node, 'Must subscript using a constant index')
idx = slice_expr.value
if not isinstance(idx, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not (0 <= idx < len(node_typ.members)):
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
tuple_member = node_typ.members[idx]
if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant')
return AccessTupleMember(
varref,
tuple_member,
)
if isinstance(node_typ, TypeStaticArray):
if not isinstance(slice_expr, ConstantPrimitive):
return AccessStaticArrayMember(
varref,
node_typ,
slice_expr,
)
idx = slice_expr.value
if not isinstance(idx, int):
_raise_static_error(node, 'Must subscript using an integer index')
if not (0 <= idx < len(node_typ.members)):
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
static_array_member = node_typ.members[idx]
return AccessStaticArrayMember(
varref,
node_typ,
static_array_member,
)
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
# if not isinstance(node.value, ast.Name):
# _raise_static_error(node, 'Must reference a name')
#
# if not isinstance(node.slice, ast.Index):
# _raise_static_error(node, 'Must subscript using an index')
#
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# varref: Union[ModuleConstantReference, VariableReference]
# if node.value.id in our_locals:
# param = our_locals[node.value.id]
# node_typ = param.type
# varref = VariableReference(param)
# elif node.value.id in module.constant_defs:
# constant_def = module.constant_defs[node.value.id]
# node_typ = constant_def.type
# varref = ModuleConstantReference(constant_def)
# else:
# _raise_static_error(node, f'Undefined variable {node.value.id}')
#
# slice_expr = self.visit_Module_FunctionDef_expr(
# module, function, our_locals, node.slice.value,
# )
#
# if isinstance(node_typ, TypeBytes):
# if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')
#
# return AccessBytesIndex(
# varref,
# slice_expr,
# )
#
# if isinstance(node_typ, TypeTuple):
# if not isinstance(slice_expr, ConstantPrimitive):
# _raise_static_error(node, 'Must subscript using a constant index')
#
# idx = slice_expr.value
#
# if not isinstance(idx, int):
# _raise_static_error(node, 'Must subscript using a constant integer index')
#
# if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
#
# tuple_member = node_typ.members[idx]
#
# if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')
#
# return AccessTupleMember(
# varref,
# tuple_member,
# )
#
# if isinstance(node_typ, TypeStaticArray):
# if not isinstance(slice_expr, ConstantPrimitive):
# return AccessStaticArrayMember(
# varref,
# node_typ,
# slice_expr,
# )
#
# idx = slice_expr.value
#
# if not isinstance(idx, int):
# _raise_static_error(node, 'Must subscript using an integer index')
#
# if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
#
# static_array_member = node_typ.members[idx]
#
# return AccessStaticArrayMember(
# varref,
# node_typ,
# static_array_member,
# )
#
# _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
del module
@ -666,10 +655,10 @@ class OurVisitor:
raise NotImplementedError(f'{node.value} as constant')
def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
def visit_type(self, module: Module, node: ast.expr) -> str:
if isinstance(node, ast.Constant):
if node.value is None:
return module.types['None']
return 'None'
_raise_static_error(node, f'Unrecognized type {node.value}')
@ -677,8 +666,10 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.id in module.types:
return module.types[node.id]
if node.id in ('u8', 'u32', 'u64', 'i32', 'i64', 'f32', 'f64'): # FIXME: Source this list somewhere
return node.id
raise NotImplementedError('TODO: Broken after type system')
if node.id in module.structs:
return module.structs[node.id]
@ -686,61 +677,65 @@ class OurVisitor:
_raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript):
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must be name')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice.value, ast.Constant):
_raise_static_error(node, 'Must subscript using a constant index')
if not isinstance(node.slice.value.value, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
raise NotImplementedError('TODO: Broken after new type system')
if node.value.id in module.types:
member_type = module.types[node.value.id]
else:
_raise_static_error(node, f'Unrecognized type {node.value.id}')
type_static_array = TypeStaticArray(member_type)
offset = 0
for idx in range(node.slice.value.value):
static_array_member = TypeStaticArrayMember(idx, offset)
type_static_array.members.append(static_array_member)
offset += member_type.alloc_size()
key = f'{node.value.id}[{node.slice.value.value}]'
if key not in module.types:
module.types[key] = type_static_array
return module.types[key]
# if not isinstance(node.value, ast.Name):
# _raise_static_error(node, 'Must be name')
# if not isinstance(node.slice, ast.Index):
# _raise_static_error(node, 'Must subscript using an index')
# if not isinstance(node.slice.value, ast.Constant):
# _raise_static_error(node, 'Must subscript using a constant index')
# if not isinstance(node.slice.value.value, int):
# _raise_static_error(node, 'Must subscript using a constant integer index')
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# if node.value.id in module.types:
# member_type = module.types[node.value.id]
# else:
# _raise_static_error(node, f'Unrecognized type {node.value.id}')
#
# type_static_array = TypeStaticArray(member_type)
#
# offset = 0
#
# for idx in range(node.slice.value.value):
# static_array_member = TypeStaticArrayMember(idx, offset)
#
# type_static_array.members.append(static_array_member)
# offset += member_type.alloc_size()
#
# key = f'{node.value.id}[{node.slice.value.value}]'
#
# if key not in module.types:
# module.types[key] = type_static_array
#
# return module.types[key]
if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
raise NotImplementedError('TODO: Broken after new type system')
type_tuple = TypeTuple()
offset = 0
for idx, elt in enumerate(node.elts):
tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset)
type_tuple.members.append(tuple_member)
offset += tuple_member.type.alloc_size()
key = type_tuple.render_internal_name()
if key not in module.types:
module.types[key] = type_tuple
constructor = TupleConstructor(type_tuple)
module.functions[constructor.name] = constructor
return module.types[key]
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# type_tuple = TypeTuple()
#
# offset = 0
#
# for idx, elt in enumerate(node.elts):
# tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset)
#
# type_tuple.members.append(tuple_member)
# offset += tuple_member.type.alloc_size()
#
# key = type_tuple.render_internal_name()
#
# if key not in module.types:
# module.types[key] = type_tuple
# constructor = TupleConstructor(type_tuple)
# module.functions[constructor.name] = constructor
#
# return module.types[key]
raise NotImplementedError(f'{node} as type')

View File

@ -3,12 +3,12 @@ Type checks and enriches the given ast
"""
from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar
from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar, from_str
def phasm_type(inp: ourlang.Module) -> None:
module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
if isinstance(inp, ourlang.ConstantPrimitive):
result = ctx.new_var()
@ -57,7 +57,7 @@ def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
raise NotImplementedError(constant, inp)
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant):
return constant(ctx, inp)
@ -109,7 +109,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None:
def function(ctx: Context, inp: ourlang.Function) -> None:
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement')
typ = expression(ctx, inp.statements[0].value)
@ -117,10 +117,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None:
assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ)
def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None:
inp.type_var = _convert_old_type(ctx, inp.type, inp.name)
def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None:
constant(ctx, inp.constant)
inp.type_var = ctx.new_var()
assert inp.constant.type_var is not None
ctx.unify(inp.type_var, inp.constant.type_var)
@ -128,66 +129,12 @@ def module(inp: ourlang.Module) -> None:
ctx = Context()
for func in inp.functions.values():
func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)')
func.returns_type_var = from_str(ctx, func.returns_str, f'{func.name}.(returns)')
for param in func.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
param.type_var = from_str(ctx, param.type_str, f'{func.name}.{param.name}')
for cdef in inp.constant_defs.values():
module_constant_def(ctx, cdef)
for func in inp.functions.values():
function(ctx, func)
from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar:
result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if isinstance(inp, typing.TypeFloat32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
if isinstance(inp, typing.TypeFloat64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location)
return result
raise NotImplementedError(_convert_old_type, inp)

View File

@ -19,88 +19,6 @@ class TypeBase:
"""
raise NotImplementedError(self, 'alloc_size')
class TypeNone(TypeBase):
"""
The None (or Void) type
"""
__slots__ = ()
class TypeBool(TypeBase):
"""
The boolean type
"""
__slots__ = ()
class TypeUInt8(TypeBase):
"""
The Integer type, unsigned and 8 bits wide
Note that under the hood we need to use i32 to represent
these values in expressions. So we need to add some operations
to make sure the math checks out.
So while this does save bytes in memory, it may not actually
speed up or improve your code.
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4 # Int32 under the hood
class TypeUInt32(TypeBase):
"""
The Integer type, unsigned and 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeUInt64(TypeBase):
"""
The Integer type, unsigned and 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeInt32(TypeBase):
"""
The Integer type, signed and 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeInt64(TypeBase):
"""
The Integer type, signed and 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeFloat32(TypeBase):
"""
The Float type, 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeFloat64(TypeBase):
"""
The Float type, 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeBytes(TypeBase):
"""
The bytes type
@ -207,7 +125,7 @@ class TypeStruct(TypeBase):
## NEW STUFF BELOW
# This error can also mean that the type somewhere forgot to write a type
# This error can also mean that the typer somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
@ -392,7 +310,12 @@ class Context:
return result
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None:
# FIXME: Write method doc, find out why pylint doesn't error
assert l is not None, ASSERTION_ERROR
assert r is not None, ASSERTION_ERROR
assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return
# Backup some values that we'll overwrite
@ -438,6 +361,8 @@ def simplify(inp: TypeVar) -> Optional[str]:
"""
Simplifies a TypeVar into a string that wasm can work with
and users can recognize
Should round trip with from_str
"""
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
@ -473,3 +398,66 @@ def simplify(inp: TypeVar) -> Optional[str]:
return f'f{tc_bits.minb}'
return None
def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
"""
Creates a new TypeVar from the string
Should round trip with simplify
The location is a reference to where you found the string
in the source code.
This could be conidered part of parsing. Though that would give trouble
with the context creation.
"""
result = ctx.new_var()
if inp == 'u8':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'u32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'u64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'i32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if inp == 'i64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if inp == 'f32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
if inp == 'f64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location)
return result
raise NotImplementedError(from_str, inp)