Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
6 changed files with 504 additions and 601 deletions
Showing only changes of commit 07c0688d1b - Show all commits

View File

@ -33,7 +33,7 @@ def struct_definition(inp: typing.TypeStruct) -> str:
Render: TypeStruct's definition Render: TypeStruct's definition
""" """
result = f'class {inp.name}:\n' result = f'class {inp.name}:\n'
for mem in inp.members: for mem in inp.members: # TODO: Broken after new type system
raise NotImplementedError('Structs broken after new type system') raise NotImplementedError('Structs broken after new type system')
# result += f' {mem.name}: {type_(mem.type)}\n' # result += f' {mem.name}: {type_(mem.type)}\n'
@ -87,11 +87,12 @@ def expression(inp: ourlang.Expression) -> str:
for arg in inp.arguments for arg in inp.arguments
) )
if isinstance(inp.function, ourlang.StructConstructor): # TODO: Broken after new type system
return f'{inp.function.struct.name}({args})' # if isinstance(inp.function, ourlang.StructConstructor):
# return f'{inp.function.struct.name}({args})'
if isinstance(inp.function, ourlang.TupleConstructor): #
return f'({args}, )' # if isinstance(inp.function, ourlang.TupleConstructor):
# return f'({args}, )'
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'

View File

@ -1,7 +1,7 @@
""" """
This module contains the code to convert parsed Ourlang into WebAssembly code This module contains the code to convert parsed Ourlang into WebAssembly code
""" """
from typing import List from typing import List, Optional
import struct import struct
@ -14,19 +14,6 @@ from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
from .wasmgenerator import Generator as WasmGenerator from .wasmgenerator import Generator as WasmGenerator
LOAD_STORE_TYPE_MAP = {
typing.TypeUInt8: 'i32',
typing.TypeUInt32: 'i32',
typing.TypeUInt64: 'i64',
typing.TypeInt32: 'i32',
typing.TypeInt64: 'i64',
typing.TypeFloat32: 'f32',
typing.TypeFloat64: 'f64',
}
"""
When generating code, we sometimes need to load or store simple values
"""
def phasm_compile(inp: ourlang.Module) -> wasm.Module: def phasm_compile(inp: ourlang.Module) -> wasm.Module:
""" """
Public method for compiling a parsed Phasm module into Public method for compiling a parsed Phasm module into
@ -34,42 +21,44 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module:
""" """
return module(inp) return module(inp)
def type_(inp: typing.TypeBase) -> wasm.WasmType: def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
""" """
Compile: type Compile: type
""" """
if isinstance(inp, typing.TypeNone): assert inp is not None, typing.ASSERTION_ERROR
return wasm.WasmTypeNone()
if isinstance(inp, typing.TypeUInt8): mtyp = typing.simplify(inp)
if mtyp == 'u8':
# WebAssembly has only support for 32 and 64 bits # WebAssembly has only support for 32 and 64 bits
# So we need to store more memory per byte # So we need to store more memory per byte
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeUInt32): if mtyp == 'u32':
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeUInt64): if mtyp == 'u64':
return wasm.WasmTypeInt64() return wasm.WasmTypeInt64()
if isinstance(inp, typing.TypeInt32): if mtyp == 'i32':
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeInt64): if mtyp == 'i64':
return wasm.WasmTypeInt64() return wasm.WasmTypeInt64()
if isinstance(inp, typing.TypeFloat32): if mtyp == 'f32':
return wasm.WasmTypeFloat32() return wasm.WasmTypeFloat32()
if isinstance(inp, typing.TypeFloat64): if mtyp == 'f64':
return wasm.WasmTypeFloat64() return wasm.WasmTypeFloat64()
if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): # TODO: Broken after new type system
# Structs and tuples are passed as pointer # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
# And pointers are i32 # # Structs and tuples are passed as pointer
return wasm.WasmTypeInt32() # # And pointers are i32
# return wasm.WasmTypeInt32()
raise NotImplementedError(type_, inp) raise NotImplementedError(inp, mtyp)
# Operators that work for i32, i64, f32, f64 # Operators that work for i32, i64, f32, f64
OPERATOR_MAP = { OPERATOR_MAP = {
@ -268,47 +257,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# wgn.call(stdlib_types.__subscript_bytes__) # wgn.call(stdlib_types.__subscript_bytes__)
# return # return
if isinstance(inp, ourlang.AccessStructMember): # if isinstance(inp, ourlang.AccessStructMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
if mtyp is None: # if mtyp is None:
# In the future might extend this by having structs or tuples # # In the future might extend this by having structs or tuples
# as members of struct or tuples # # as members of struct or tuples
raise NotImplementedError(expression, inp, inp.member) # raise NotImplementedError(expression, inp, inp.member)
#
expression(wgn, inp.varref) # expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return # return
#
if isinstance(inp, ourlang.AccessTupleMember): # if isinstance(inp, ourlang.AccessTupleMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
if mtyp is None: # if mtyp is None:
# In the future might extend this by having structs or tuples # # In the future might extend this by having structs or tuples
# as members of struct or tuples # # as members of struct or tuples
raise NotImplementedError(expression, inp, inp.member) # raise NotImplementedError(expression, inp, inp.member)
#
expression(wgn, inp.varref) # expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return # return
#
if isinstance(inp, ourlang.AccessStaticArrayMember): # if isinstance(inp, ourlang.AccessStaticArrayMember):
mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) # mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__)
if mtyp is None: # if mtyp is None:
# In the future might extend this by having structs or tuples # # In the future might extend this by having structs or tuples
# as members of static arrays # # as members of static arrays
raise NotImplementedError(expression, inp, inp.member) # raise NotImplementedError(expression, inp, inp.member)
#
if isinstance(inp.member, typing.TypeStaticArrayMember): # if isinstance(inp.member, typing.TypeStaticArrayMember):
expression(wgn, inp.varref) # expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return # return
#
expression(wgn, inp.varref) # expression(wgn, inp.varref)
expression(wgn, inp.member) # expression(wgn, inp.member)
wgn.i32.const(inp.static_array.member_type.alloc_size()) # wgn.i32.const(inp.static_array.member_type.alloc_size())
wgn.i32.mul() # wgn.i32.mul()
wgn.i32.add() # wgn.i32.add()
wgn.add_statement(f'{mtyp}.load') # wgn.add_statement(f'{mtyp}.load')
return # return
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp) expression_fold(wgn, inp)
@ -472,7 +461,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
""" """
Compile: function argument Compile: function argument
""" """
return (inp.name, type_(inp.type), ) return (inp.name, type_var(inp.type_var), )
def import_(inp: ourlang.Function) -> wasm.Import: def import_(inp: ourlang.Function) -> wasm.Import:
""" """
@ -488,7 +477,7 @@ def import_(inp: ourlang.Function) -> wasm.Import:
function_argument(x) function_argument(x)
for x in inp.posonlyargs for x in inp.posonlyargs
], ],
type_(inp.returns) type_var(inp.returns_type_var)
) )
def function(inp: ourlang.Function) -> wasm.Function: def function(inp: ourlang.Function) -> wasm.Function:
@ -499,10 +488,10 @@ def function(inp: ourlang.Function) -> wasm.Function:
wgn = WasmGenerator() wgn = WasmGenerator()
if isinstance(inp, ourlang.TupleConstructor): if False: # TODO: isinstance(inp, ourlang.TupleConstructor):
_generate_tuple_constructor(wgn, inp) pass # _generate_tuple_constructor(wgn, inp)
elif isinstance(inp, ourlang.StructConstructor): elif False: # TODO: isinstance(inp, ourlang.StructConstructor):
_generate_struct_constructor(wgn, inp) pass # _generate_struct_constructor(wgn, inp)
else: else:
for stat in inp.statements: for stat in inp.statements:
statement(wgn, stat) statement(wgn, stat)
@ -518,7 +507,7 @@ def function(inp: ourlang.Function) -> wasm.Function:
(k, v.wasm_type(), ) (k, v.wasm_type(), )
for k, v in wgn.locals.items() for k, v in wgn.locals.items()
], ],
type_(inp.returns), type_var(inp.returns_type_var),
wgn.statements wgn.statements
) )
@ -660,48 +649,49 @@ def module(inp: ourlang.Module) -> wasm.Module:
return result return result
def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: # TODO: Broken after new type system
tmp_var = wgn.temp_var_i32('tuple_adr') # def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None:
# tmp_var = wgn.temp_var_i32('tuple_adr')
# Allocated the required amounts of bytes in memory #
wgn.i32.const(inp.tuple.alloc_size()) # # Allocated the required amounts of bytes in memory
wgn.call(stdlib_alloc.__alloc__) # wgn.i32.const(inp.tuple.alloc_size())
wgn.local.set(tmp_var) # wgn.call(stdlib_alloc.__alloc__)
# wgn.local.set(tmp_var)
# Store each member individually #
for member in inp.tuple.members: # # Store each member individually
mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) # for member in inp.tuple.members:
if mtyp is None: # mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
# In the future might extend this by having structs or tuples # if mtyp is None:
# as members of struct or tuples # # In the future might extend this by having structs or tuples
raise NotImplementedError(expression, inp, member) # # as members of struct or tuples
# raise NotImplementedError(expression, inp, member)
wgn.local.get(tmp_var) #
wgn.add_statement('local.get', f'$arg{member.idx}') # wgn.local.get(tmp_var)
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) # wgn.add_statement('local.get', f'$arg{member.idx}')
# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
# Return the allocated address #
wgn.local.get(tmp_var) # # Return the allocated address
# wgn.local.get(tmp_var)
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: #
tmp_var = wgn.temp_var_i32('struct_adr') # def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:
# tmp_var = wgn.temp_var_i32('struct_adr')
# Allocated the required amounts of bytes in memory #
wgn.i32.const(inp.struct.alloc_size()) # # Allocated the required amounts of bytes in memory
wgn.call(stdlib_alloc.__alloc__) # wgn.i32.const(inp.struct.alloc_size())
wgn.local.set(tmp_var) # wgn.call(stdlib_alloc.__alloc__)
# wgn.local.set(tmp_var)
# Store each member individually #
for member in inp.struct.members: # # Store each member individually
mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) # for member in inp.struct.members:
if mtyp is None: # mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
# In the future might extend this by having structs or tuples # if mtyp is None:
# as members of struct or tuples # # In the future might extend this by having structs or tuples
raise NotImplementedError(expression, inp, member) # # as members of struct or tuples
# raise NotImplementedError(expression, inp, member)
wgn.local.get(tmp_var) #
wgn.add_statement('local.get', f'${member.name}') # wgn.local.get(tmp_var)
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) # wgn.add_statement('local.get', f'${member.name}')
# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
# Return the allocated address #
wgn.local.get(tmp_var) # # Return the allocated address
# wgn.local.get(tmp_var)

View File

@ -11,11 +11,6 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc',
WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
from .typing import ( from .typing import (
TypeBase,
TypeNone,
TypeUInt8, TypeUInt32, TypeUInt64,
TypeInt32, TypeInt64,
TypeFloat32, TypeFloat64,
TypeBytes, TypeBytes,
TypeTuple, TypeTupleMember, TypeTuple, TypeTupleMember,
TypeStaticArray, TypeStaticArrayMember, TypeStaticArray, TypeStaticArrayMember,
@ -280,29 +275,29 @@ class FunctionParam:
""" """
A parameter for a Function A parameter for a Function
""" """
__slots__ = ('name', 'type', 'type_var', ) __slots__ = ('name', 'type_str', 'type_var', )
name: str name: str
type: TypeBase type_str: str
type_var: Optional[TypeVar] type_var: Optional[TypeVar]
def __init__(self, name: str, type_: TypeBase) -> None: def __init__(self, name: str, type_str: str) -> None:
self.name = name self.name = name
self.type = type_ self.type_str = type_str
self.type_var = None self.type_var = None
class Function: class Function:
""" """
A function processes input and produces output A function processes input and produces output
""" """
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', ) __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns_str', 'returns_type_var', 'posonlyargs', )
name: str name: str
lineno: int lineno: int
exported: bool exported: bool
imported: bool imported: bool
statements: List[Statement] statements: List[Statement]
returns: TypeBase returns_str: str
returns_type_var: Optional[TypeVar] returns_type_var: Optional[TypeVar]
posonlyargs: List[FunctionParam] posonlyargs: List[FunctionParam]
@ -312,68 +307,67 @@ class Function:
self.exported = False self.exported = False
self.imported = False self.imported = False
self.statements = [] self.statements = []
self.returns = TypeNone() self.returns_str = 'None'
self.returns_type_var = None self.returns_type_var = None
self.posonlyargs = [] self.posonlyargs = []
class StructConstructor(Function): # TODO: Broken after new type system
""" # class StructConstructor(Function):
The constructor method for a struct # """
# The constructor method for a struct
A function will generated to instantiate a struct. The arguments #
will be the defaults # A function will generated to instantiate a struct. The arguments
""" # will be the defaults
__slots__ = ('struct', ) # """
# __slots__ = ('struct', )
struct: TypeStruct #
# struct: TypeStruct
def __init__(self, struct: TypeStruct) -> None: #
super().__init__(f'@{struct.name}@__init___@', -1) # def __init__(self, struct: TypeStruct) -> None:
# super().__init__(f'@{struct.name}@__init___@', -1)
self.returns = struct #
# self.returns = struct
for mem in struct.members: #
self.posonlyargs.append(FunctionParam(mem.name, mem.type, )) # for mem in struct.members:
# self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
self.struct = struct #
# self.struct = struct
class TupleConstructor(Function): #
""" # class TupleConstructor(Function):
The constructor method for a tuple # """
""" # The constructor method for a tuple
__slots__ = ('tuple', ) # """
# __slots__ = ('tuple', )
tuple: TypeTuple #
# tuple: TypeTuple
def __init__(self, tuple_: TypeTuple) -> None: #
name = tuple_.render_internal_name() # def __init__(self, tuple_: TypeTuple) -> None:
# name = tuple_.render_internal_name()
super().__init__(f'@{name}@__init___@', -1) #
# super().__init__(f'@{name}@__init___@', -1)
self.returns = tuple_ #
# self.returns = tuple_
for mem in tuple_.members: #
self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, )) # for mem in tuple_.members:
# self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
self.tuple = tuple_ #
# self.tuple = tuple_
class ModuleConstantDef: class ModuleConstantDef:
""" """
A constant definition within a module A constant definition within a module
""" """
__slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', ) __slots__ = ('name', 'lineno', 'type_var', 'constant', 'data_block', )
name: str name: str
lineno: int lineno: int
type: TypeBase
type_var: Optional[TypeVar] type_var: Optional[TypeVar]
constant: Constant constant: Constant
data_block: Optional['ModuleDataBlock'] data_block: Optional['ModuleDataBlock']
def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: def __init__(self, name: str, lineno: int, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None:
self.name = name self.name = name
self.lineno = lineno self.lineno = lineno
self.type = type_
self.type_var = None self.type_var = None
self.constant = constant self.constant = constant
self.data_block = data_block self.data_block = data_block
@ -409,23 +403,11 @@ class Module:
__slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',) __slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',)
data: ModuleData data: ModuleData
types: Dict[str, TypeBase]
structs: Dict[str, TypeStruct] structs: Dict[str, TypeStruct]
constant_defs: Dict[str, ModuleConstantDef] constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function] functions: Dict[str, Function]
def __init__(self) -> None: def __init__(self) -> None:
self.types = {
'None': TypeNone(),
'u8': TypeUInt8(),
'u32': TypeUInt32(),
'u64': TypeUInt64(),
'i32': TypeInt32(),
'i64': TypeInt64(),
'f32': TypeFloat32(),
'f64': TypeFloat64(),
'bytes': TypeBytes(),
}
self.data = ModuleData() self.data = ModuleData()
self.structs = {} self.structs = {}
self.constant_defs = {} self.constant_defs = {}

View File

@ -6,8 +6,6 @@ from typing import Any, Dict, NoReturn, Union
import ast import ast
from .typing import ( from .typing import (
TypeBase,
TypeBytes,
TypeStruct, TypeStruct,
TypeStructMember, TypeStructMember,
TypeTuple, TypeTuple,
@ -16,7 +14,6 @@ from .typing import (
TypeStaticArrayMember, TypeStaticArrayMember,
) )
from . import codestyle
from .exceptions import StaticError from .exceptions import StaticError
from .ourlang import ( from .ourlang import (
WEBASSEMBLY_BUILDIN_FLOAT_OPS, WEBASSEMBLY_BUILDIN_FLOAT_OPS,
@ -30,7 +27,7 @@ from .ourlang import (
ConstantPrimitive, ConstantTuple, ConstantStaticArray, ConstantPrimitive, ConstantTuple, ConstantStaticArray,
FunctionCall, FunctionCall,
StructConstructor, TupleConstructor, # StructConstructor, TupleConstructor,
UnaryOp, VariableReference, UnaryOp, VariableReference,
Fold, ModuleConstantReference, Fold, ModuleConstantReference,
@ -86,15 +83,16 @@ class OurVisitor:
module.constant_defs[res.name] = res module.constant_defs[res.name] = res
if isinstance(res, TypeStruct): # TODO: Broken after type system
if res.name in module.structs: # if isinstance(res, TypeStruct):
raise StaticError( # if res.name in module.structs:
f'{res.name} already defined on line {module.structs[res.name].lineno}' # raise StaticError(
) # f'{res.name} already defined on line {module.structs[res.name].lineno}'
# )
module.structs[res.name] = res #
constructor = StructConstructor(res) # module.structs[res.name] = res
module.functions[constructor.name] = constructor # constructor = StructConstructor(res)
# module.functions[constructor.name] = constructor
if isinstance(res, Function): if isinstance(res, Function):
if res.name in module.functions: if res.name in module.functions:
@ -158,7 +156,7 @@ class OurVisitor:
function.imported = True function.imported = True
if node.returns: if node.returns:
function.returns = self.visit_type(module, node.returns) function.returns_str = self.visit_type(module, node.returns)
_not_implemented(not node.type_comment, 'FunctionDef.type_comment') _not_implemented(not node.type_comment, 'FunctionDef.type_comment')
@ -186,6 +184,7 @@ class OurVisitor:
if stmt.simple != 1: if stmt.simple != 1:
raise NotImplementedError('Class with non-simple arguments') raise NotImplementedError('Class with non-simple arguments')
raise NotImplementedError('TODO: Broken after new type system')
member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset)
struct.members.append(member) struct.members.append(member)
@ -199,74 +198,72 @@ class OurVisitor:
if not isinstance(node.target.ctx, ast.Store): if not isinstance(node.target.ctx, ast.Store):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
exp_type = self.visit_type(module, node.annotation)
if isinstance(node.value, ast.Constant): if isinstance(node.value, ast.Constant):
return ModuleConstantDef( return ModuleConstantDef(
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type,
self.visit_Module_Constant(module, node.value), self.visit_Module_Constant(module, node.value),
None, None,
) )
if isinstance(exp_type, TypeTuple): raise NotImplementedError('TODO: Broken after new typing system')
if not isinstance(node.value, ast.Tuple):
_raise_static_error(node, 'Must be tuple')
if len(exp_type.members) != len(node.value.elts): # if isinstance(exp_type, TypeTuple):
_raise_static_error(node, 'Invalid number of tuple values') # if not isinstance(node.value, ast.Tuple):
# _raise_static_error(node, 'Must be tuple')
tuple_data = [ #
self.visit_Module_Constant(module, arg_node) # if len(exp_type.members) != len(node.value.elts):
for arg_node, mem in zip(node.value.elts, exp_type.members) # _raise_static_error(node, 'Invalid number of tuple values')
if isinstance(arg_node, ast.Constant) #
] # tuple_data = [
if len(exp_type.members) != len(tuple_data): # self.visit_Module_Constant(module, arg_node)
_raise_static_error(node, 'Tuple arguments must be constants') # for arg_node, mem in zip(node.value.elts, exp_type.members)
# if isinstance(arg_node, ast.Constant)
# Allocate the data # ]
data_block = ModuleDataBlock(tuple_data) # if len(exp_type.members) != len(tuple_data):
module.data.blocks.append(data_block) # _raise_static_error(node, 'Tuple arguments must be constants')
#
# Then return the constant as a pointer # # Allocate the data
return ModuleConstantDef( # data_block = ModuleDataBlock(tuple_data)
node.target.id, # module.data.blocks.append(data_block)
node.lineno, #
exp_type, # # Then return the constant as a pointer
ConstantTuple(tuple_data), # return ModuleConstantDef(
data_block, # node.target.id,
) # node.lineno,
# exp_type,
if isinstance(exp_type, TypeStaticArray): # ConstantTuple(tuple_data),
if not isinstance(node.value, ast.Tuple): # data_block,
_raise_static_error(node, 'Must be static array') # )
#
if len(exp_type.members) != len(node.value.elts): # if isinstance(exp_type, TypeStaticArray):
_raise_static_error(node, 'Invalid number of static array values') # if not isinstance(node.value, ast.Tuple):
# _raise_static_error(node, 'Must be static array')
static_array_data = [ #
self.visit_Module_Constant(module, arg_node) # if len(exp_type.members) != len(node.value.elts):
for arg_node in node.value.elts # _raise_static_error(node, 'Invalid number of static array values')
if isinstance(arg_node, ast.Constant) #
] # static_array_data = [
if len(exp_type.members) != len(static_array_data): # self.visit_Module_Constant(module, arg_node)
_raise_static_error(node, 'Static array arguments must be constants') # for arg_node in node.value.elts
# if isinstance(arg_node, ast.Constant)
# Allocate the data # ]
data_block = ModuleDataBlock(static_array_data) # if len(exp_type.members) != len(static_array_data):
module.data.blocks.append(data_block) # _raise_static_error(node, 'Static array arguments must be constants')
#
# Then return the constant as a pointer # # Allocate the data
return ModuleConstantDef( # data_block = ModuleDataBlock(static_array_data)
node.target.id, # module.data.blocks.append(data_block)
node.lineno, #
exp_type, # # Then return the constant as a pointer
ConstantStaticArray(static_array_data), # return ModuleConstantDef(
data_block, # node.target.id,
) # node.lineno,
# ConstantStaticArray(static_array_data),
raise NotImplementedError(f'{node} on Module AnnAssign') # data_block,
# )
#
# raise NotImplementedError(f'{node} on Module AnnAssign')
def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
@ -301,12 +298,12 @@ class OurVisitor:
_raise_static_error(node, 'Return must have an argument') _raise_static_error(node, 'Return must have an argument')
return StatementReturn( return StatementReturn(
self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
) )
if isinstance(node, ast.If): if isinstance(node, ast.If):
result = StatementIf( result = StatementIf(
self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test)
) )
for stmt in node.body: for stmt in node.body:
@ -326,7 +323,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as stmt in FunctionDef') raise NotImplementedError(f'{node} as stmt in FunctionDef')
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression:
if isinstance(node, ast.BinOp): if isinstance(node, ast.BinOp):
if isinstance(node.op, ast.Add): if isinstance(node.op, ast.Add):
operator = '+' operator = '+'
@ -352,8 +349,8 @@ class OurVisitor:
return BinaryOp( return BinaryOp(
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
) )
if isinstance(node, ast.UnaryOp): if isinstance(node, ast.UnaryOp):
@ -366,7 +363,7 @@ class OurVisitor:
return UnaryOp( return UnaryOp(
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.operand),
) )
if isinstance(node, ast.Compare): if isinstance(node, ast.Compare):
@ -387,12 +384,12 @@ class OurVisitor:
return BinaryOp( return BinaryOp(
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
) )
if isinstance(node, ast.Call): if isinstance(node, ast.Call):
return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) return self.visit_Module_FunctionDef_Call(module, function, our_locals, node)
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return self.visit_Module_Constant( return self.visit_Module_Constant(
@ -426,29 +423,29 @@ class OurVisitor:
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
raise NotImplementedError('TODO: Broken after new type system') raise NotImplementedError('TODO: Broken after new type system')
if not isinstance(node.ctx, ast.Load): # if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') # _raise_static_error(node, 'Must be load context')
#
if isinstance(exp_type, TypeTuple): # if isinstance(exp_type, TypeTuple):
if len(exp_type.members) != len(node.elts): # if len(exp_type.members) != len(node.elts):
_raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') # _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given')
#
tuple_constructor = TupleConstructor(exp_type) # tuple_constructor = TupleConstructor(exp_type)
#
func = module.functions[tuple_constructor.name] # func = module.functions[tuple_constructor.name]
#
result = FunctionCall(func) # result = FunctionCall(func)
result.arguments = [ # result.arguments = [
self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) # self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node)
for arg_node, mem in zip(node.elts, exp_type.members) # for arg_node, mem in zip(node.elts, exp_type.members)
] # ]
return result # return result
#
_raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') # _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple')
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -458,17 +455,18 @@ class OurVisitor:
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.func.id in module.structs: if node.func.id in module.structs:
struct = module.structs[node.func.id] raise NotImplementedError('TODO: Broken after new type system')
struct_constructor = StructConstructor(struct) # struct = module.structs[node.func.id]
# struct_constructor = StructConstructor(struct)
func = module.functions[struct_constructor.name] #
# func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp( return UnaryOp(
'sqrt', 'sqrt',
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
) )
elif node.func.id == 'u32': elif node.func.id == 'u32':
if 1 != len(node.args): if 1 != len(node.args):
@ -478,7 +476,7 @@ class OurVisitor:
return UnaryOp( return UnaryOp(
'cast', 'cast',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
) )
elif node.func.id == 'len': elif node.func.id == 'len':
if 1 != len(node.args): if 1 != len(node.args):
@ -486,7 +484,7 @@ class OurVisitor:
return UnaryOp( return UnaryOp(
'len', 'len',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
) )
elif node.func.id == 'foldl': elif node.func.id == 'foldl':
# TODO: This should a much more generic function! # TODO: This should a much more generic function!
@ -511,20 +509,11 @@ class OurVisitor:
raise NotImplementedError('TODO: Broken after new type system') raise NotImplementedError('TODO: Broken after new type system')
if exp_type.__class__ != func.returns.__class__:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if func.returns.__class__ != func.posonlyargs[0].type.__class__:
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')
if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold( return Fold(
Fold.Direction.LEFT, Fold.Direction.LEFT,
func, func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
) )
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
@ -532,20 +521,18 @@ class OurVisitor:
func = module.functions[node.func.id] func = module.functions[node.func.id]
# if func.returns != exp_type:
# _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if len(func.posonlyargs) != len(node.args): if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
result = FunctionCall(func) result = FunctionCall(func)
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr, param in zip(node.args, func.posonlyargs) for arg_expr, param in zip(node.args, func.posonlyargs)
) )
return result return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
raise NotImplementedError('Broken after new type system')
del module del module
del function del function
@ -574,87 +561,89 @@ class OurVisitor:
) )
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name): raise NotImplementedError('TODO: Broken after new type system')
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.slice, ast.Index): # if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must subscript using an index') # _raise_static_error(node, 'Must reference a name')
#
if not isinstance(node.ctx, ast.Load): # if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must be load context') # _raise_static_error(node, 'Must subscript using an index')
#
varref: Union[ModuleConstantReference, VariableReference] # if not isinstance(node.ctx, ast.Load):
if node.value.id in our_locals: # _raise_static_error(node, 'Must be load context')
param = our_locals[node.value.id] #
node_typ = param.type # varref: Union[ModuleConstantReference, VariableReference]
varref = VariableReference(param) # if node.value.id in our_locals:
elif node.value.id in module.constant_defs: # param = our_locals[node.value.id]
constant_def = module.constant_defs[node.value.id] # node_typ = param.type
node_typ = constant_def.type # varref = VariableReference(param)
varref = ModuleConstantReference(constant_def) # elif node.value.id in module.constant_defs:
else: # constant_def = module.constant_defs[node.value.id]
_raise_static_error(node, f'Undefined variable {node.value.id}') # node_typ = constant_def.type
# varref = ModuleConstantReference(constant_def)
slice_expr = self.visit_Module_FunctionDef_expr( # else:
module, function, our_locals, module.types['u32'], node.slice.value, # _raise_static_error(node, f'Undefined variable {node.value.id}')
) #
# slice_expr = self.visit_Module_FunctionDef_expr(
if isinstance(node_typ, TypeBytes): # module, function, our_locals, node.slice.value,
if isinstance(varref, ModuleConstantReference): # )
raise NotImplementedError(f'{node} from module constant') #
# if isinstance(node_typ, TypeBytes):
return AccessBytesIndex( # if isinstance(varref, ModuleConstantReference):
varref, # raise NotImplementedError(f'{node} from module constant')
slice_expr, #
) # return AccessBytesIndex(
# varref,
if isinstance(node_typ, TypeTuple): # slice_expr,
if not isinstance(slice_expr, ConstantPrimitive): # )
_raise_static_error(node, 'Must subscript using a constant index') #
# if isinstance(node_typ, TypeTuple):
idx = slice_expr.value # if not isinstance(slice_expr, ConstantPrimitive):
# _raise_static_error(node, 'Must subscript using a constant index')
if not isinstance(idx, int): #
_raise_static_error(node, 'Must subscript using a constant integer index') # idx = slice_expr.value
#
if not (0 <= idx < len(node_typ.members)): # if not isinstance(idx, int):
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') # _raise_static_error(node, 'Must subscript using a constant integer index')
#
tuple_member = node_typ.members[idx] # if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
if isinstance(varref, ModuleConstantReference): #
raise NotImplementedError(f'{node} from module constant') # tuple_member = node_typ.members[idx]
#
return AccessTupleMember( # if isinstance(varref, ModuleConstantReference):
varref, # raise NotImplementedError(f'{node} from module constant')
tuple_member, #
) # return AccessTupleMember(
# varref,
if isinstance(node_typ, TypeStaticArray): # tuple_member,
if not isinstance(slice_expr, ConstantPrimitive): # )
return AccessStaticArrayMember( #
varref, # if isinstance(node_typ, TypeStaticArray):
node_typ, # if not isinstance(slice_expr, ConstantPrimitive):
slice_expr, # return AccessStaticArrayMember(
) # varref,
# node_typ,
idx = slice_expr.value # slice_expr,
# )
if not isinstance(idx, int): #
_raise_static_error(node, 'Must subscript using an integer index') # idx = slice_expr.value
#
if not (0 <= idx < len(node_typ.members)): # if not isinstance(idx, int):
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') # _raise_static_error(node, 'Must subscript using an integer index')
#
static_array_member = node_typ.members[idx] # if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
return AccessStaticArrayMember( #
varref, # static_array_member = node_typ.members[idx]
node_typ, #
static_array_member, # return AccessStaticArrayMember(
) # varref,
# node_typ,
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') # static_array_member,
# )
#
# _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
del module del module
@ -666,10 +655,10 @@ class OurVisitor:
raise NotImplementedError(f'{node.value} as constant') raise NotImplementedError(f'{node.value} as constant')
def visit_type(self, module: Module, node: ast.expr) -> TypeBase: def visit_type(self, module: Module, node: ast.expr) -> str:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
if node.value is None: if node.value is None:
return module.types['None'] return 'None'
_raise_static_error(node, f'Unrecognized type {node.value}') _raise_static_error(node, f'Unrecognized type {node.value}')
@ -677,8 +666,10 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.id in module.types: if node.id in ('u8', 'u32', 'u64', 'i32', 'i64', 'f32', 'f64'): # FIXME: Source this list somewhere
return module.types[node.id] return node.id
raise NotImplementedError('TODO: Broken after type system')
if node.id in module.structs: if node.id in module.structs:
return module.structs[node.id] return module.structs[node.id]
@ -686,61 +677,65 @@ class OurVisitor:
_raise_static_error(node, f'Unrecognized type {node.id}') _raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript): if isinstance(node, ast.Subscript):
if not isinstance(node.value, ast.Name): raise NotImplementedError('TODO: Broken after new type system')
_raise_static_error(node, 'Must be name')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice.value, ast.Constant):
_raise_static_error(node, 'Must subscript using a constant index')
if not isinstance(node.slice.value.value, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.value.id in module.types: # if not isinstance(node.value, ast.Name):
member_type = module.types[node.value.id] # _raise_static_error(node, 'Must be name')
else: # if not isinstance(node.slice, ast.Index):
_raise_static_error(node, f'Unrecognized type {node.value.id}') # _raise_static_error(node, 'Must subscript using an index')
# if not isinstance(node.slice.value, ast.Constant):
type_static_array = TypeStaticArray(member_type) # _raise_static_error(node, 'Must subscript using a constant index')
# if not isinstance(node.slice.value.value, int):
offset = 0 # _raise_static_error(node, 'Must subscript using a constant integer index')
# if not isinstance(node.ctx, ast.Load):
for idx in range(node.slice.value.value): # _raise_static_error(node, 'Must be load context')
static_array_member = TypeStaticArrayMember(idx, offset) #
# if node.value.id in module.types:
type_static_array.members.append(static_array_member) # member_type = module.types[node.value.id]
offset += member_type.alloc_size() # else:
# _raise_static_error(node, f'Unrecognized type {node.value.id}')
key = f'{node.value.id}[{node.slice.value.value}]' #
# type_static_array = TypeStaticArray(member_type)
if key not in module.types: #
module.types[key] = type_static_array # offset = 0
#
return module.types[key] # for idx in range(node.slice.value.value):
# static_array_member = TypeStaticArrayMember(idx, offset)
#
# type_static_array.members.append(static_array_member)
# offset += member_type.alloc_size()
#
# key = f'{node.value.id}[{node.slice.value.value}]'
#
# if key not in module.types:
# module.types[key] = type_static_array
#
# return module.types[key]
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load): raise NotImplementedError('TODO: Broken after new type system')
_raise_static_error(node, 'Must be load context')
type_tuple = TypeTuple() # if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
offset = 0 #
# type_tuple = TypeTuple()
for idx, elt in enumerate(node.elts): #
tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) # offset = 0
#
type_tuple.members.append(tuple_member) # for idx, elt in enumerate(node.elts):
offset += tuple_member.type.alloc_size() # tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset)
#
key = type_tuple.render_internal_name() # type_tuple.members.append(tuple_member)
# offset += tuple_member.type.alloc_size()
if key not in module.types: #
module.types[key] = type_tuple # key = type_tuple.render_internal_name()
constructor = TupleConstructor(type_tuple) #
module.functions[constructor.name] = constructor # if key not in module.types:
# module.types[key] = type_tuple
return module.types[key] # constructor = TupleConstructor(type_tuple)
# module.functions[constructor.name] = constructor
#
# return module.types[key]
raise NotImplementedError(f'{node} as type') raise NotImplementedError(f'{node} as type')

View File

@ -3,12 +3,12 @@ Type checks and enriches the given ast
""" """
from . import ourlang from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar, from_str
def phasm_type(inp: ourlang.Module) -> None: def phasm_type(inp: ourlang.Module) -> None:
module(inp) module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
if isinstance(inp, ourlang.ConstantPrimitive): if isinstance(inp, ourlang.ConstantPrimitive):
result = ctx.new_var() result = ctx.new_var()
@ -57,7 +57,7 @@ def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
raise NotImplementedError(constant, inp) raise NotImplementedError(constant, inp)
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
return constant(ctx, inp) return constant(ctx, inp)
@ -109,7 +109,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None: def function(ctx: Context, inp: ourlang.Function) -> None:
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement') raise NotImplementedError('Functions with not just a return statement')
typ = expression(ctx, inp.statements[0].value) typ = expression(ctx, inp.statements[0].value)
@ -117,10 +117,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None:
assert inp.returns_type_var is not None assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ) ctx.unify(inp.returns_type_var, typ)
def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None: def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None:
inp.type_var = _convert_old_type(ctx, inp.type, inp.name)
constant(ctx, inp.constant) constant(ctx, inp.constant)
inp.type_var = ctx.new_var()
assert inp.constant.type_var is not None assert inp.constant.type_var is not None
ctx.unify(inp.type_var, inp.constant.type_var) ctx.unify(inp.type_var, inp.constant.type_var)
@ -128,66 +129,12 @@ def module(inp: ourlang.Module) -> None:
ctx = Context() ctx = Context()
for func in inp.functions.values(): for func in inp.functions.values():
func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)') func.returns_type_var = from_str(ctx, func.returns_str, f'{func.name}.(returns)')
for param in func.posonlyargs: for param in func.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') param.type_var = from_str(ctx, param.type_str, f'{func.name}.{param.name}')
for cdef in inp.constant_defs.values(): for cdef in inp.constant_defs.values():
module_constant_def(ctx, cdef) module_constant_def(ctx, cdef)
for func in inp.functions.values(): for func in inp.functions.values():
function(ctx, func) function(ctx, func)
from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar:
result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if isinstance(inp, typing.TypeFloat32):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
if isinstance(inp, typing.TypeFloat64):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location)
return result
raise NotImplementedError(_convert_old_type, inp)

View File

@ -19,88 +19,6 @@ class TypeBase:
""" """
raise NotImplementedError(self, 'alloc_size') raise NotImplementedError(self, 'alloc_size')
class TypeNone(TypeBase):
"""
The None (or Void) type
"""
__slots__ = ()
class TypeBool(TypeBase):
"""
The boolean type
"""
__slots__ = ()
class TypeUInt8(TypeBase):
"""
The Integer type, unsigned and 8 bits wide
Note that under the hood we need to use i32 to represent
these values in expressions. So we need to add some operations
to make sure the math checks out.
So while this does save bytes in memory, it may not actually
speed up or improve your code.
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4 # Int32 under the hood
class TypeUInt32(TypeBase):
"""
The Integer type, unsigned and 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeUInt64(TypeBase):
"""
The Integer type, unsigned and 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeInt32(TypeBase):
"""
The Integer type, signed and 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeInt64(TypeBase):
"""
The Integer type, signed and 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeFloat32(TypeBase):
"""
The Float type, 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeFloat64(TypeBase):
"""
The Float type, 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeBytes(TypeBase): class TypeBytes(TypeBase):
""" """
The bytes type The bytes type
@ -207,7 +125,7 @@ class TypeStruct(TypeBase):
## NEW STUFF BELOW ## NEW STUFF BELOW
# This error can also mean that the type somewhere forgot to write a type # This error can also mean that the typer somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer. # back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
@ -392,7 +310,12 @@ class Context:
return result return result
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None:
# FIXME: Write method doc, find out why pylint doesn't error
assert l is not None, ASSERTION_ERROR
assert r is not None, ASSERTION_ERROR
assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return
# Backup some values that we'll overwrite # Backup some values that we'll overwrite
@ -438,6 +361,8 @@ def simplify(inp: TypeVar) -> Optional[str]:
""" """
Simplifies a TypeVar into a string that wasm can work with Simplifies a TypeVar into a string that wasm can work with
and users can recognize and users can recognize
Should round trip with from_str
""" """
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
@ -473,3 +398,66 @@ def simplify(inp: TypeVar) -> Optional[str]:
return f'f{tc_bits.minb}' return f'f{tc_bits.minb}'
return None return None
def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
"""
Creates a new TypeVar from the string
Should round trip with simplify
The location is a reference to where you found the string
in the source code.
This could be conidered part of parsing. Though that would give trouble
with the context creation.
"""
result = ctx.new_var()
if inp == 'u8':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'u32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'u64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if inp == 'i32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if inp == 'i64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if inp == 'f32':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location(location)
return result
if inp == 'f64':
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location)
return result
raise NotImplementedError(from_str, inp)