Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
31 changed files with 2216 additions and 1743 deletions

View File

@ -3,7 +3,7 @@ This module generates source code based on the parsed AST
It's intented to be a "any color, as long as it's black" kind of renderer It's intented to be a "any color, as long as it's black" kind of renderer
""" """
from typing import Generator from typing import Generator, Optional
from . import ourlang from . import ourlang
from . import typing from . import typing
@ -16,63 +16,26 @@ def phasm_render(inp: ourlang.Module) -> str:
Statements = Generator[str, None, None] Statements = Generator[str, None, None]
def type_(inp: typing.TypeBase) -> str: def type_var(inp: Optional[typing.TypeVar]) -> str:
""" """
Render: Type (name) Render: type's name
""" """
if isinstance(inp, typing.TypeNone): assert inp is not None, typing.ASSERTION_ERROR
return 'None'
if isinstance(inp, typing.TypeBool): mtyp = typing.simplify(inp)
return 'bool' if mtyp is None:
raise NotImplementedError(f'Rendering type {inp}')
if isinstance(inp, typing.TypeUInt8): return mtyp
return 'u8'
if isinstance(inp, typing.TypeUInt32):
return 'u32'
if isinstance(inp, typing.TypeUInt64):
return 'u64'
if isinstance(inp, typing.TypeInt32):
return 'i32'
if isinstance(inp, typing.TypeInt64):
return 'i64'
if isinstance(inp, typing.TypeFloat32):
return 'f32'
if isinstance(inp, typing.TypeFloat64):
return 'f64'
if isinstance(inp, typing.TypeBytes):
return 'bytes'
if isinstance(inp, typing.TypeTuple):
mems = ', '.join(
type_(x.type)
for x in inp.members
)
return f'({mems}, )'
if isinstance(inp, typing.TypeStaticArray):
return f'{type_(inp.member_type)}[{len(inp.members)}]'
if isinstance(inp, typing.TypeStruct):
return inp.name
raise NotImplementedError(type_, inp)
def struct_definition(inp: typing.TypeStruct) -> str: def struct_definition(inp: typing.TypeStruct) -> str:
""" """
Render: TypeStruct's definition Render: TypeStruct's definition
""" """
result = f'class {inp.name}:\n' result = f'class {inp.name}:\n'
for mem in inp.members: for mem in inp.members: # TODO: Broken after new type system
result += f' {mem.name}: {type_(mem.type)}\n' raise NotImplementedError('Structs broken after new type system')
# result += f' {mem.name}: {type_(mem.type)}\n'
return result return result
@ -80,40 +43,38 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str:
""" """
Render: Module Constant's definition Render: Module Constant's definition
""" """
return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' return f'{inp.name}: {type_var(inp.type_var)} = {expression(inp.constant)}\n'
def expression(inp: ourlang.Expression) -> str: def expression(inp: ourlang.Expression) -> str:
""" """
Render: A Phasm expression Render: A Phasm expression
""" """
if isinstance(inp, ( if isinstance(inp, ourlang.ConstantPrimitive):
ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, # Floats might not round trip if the original constant
ourlang.ConstantInt32, ourlang.ConstantInt64,
)):
return str(inp.value)
if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )):
# These might not round trip if the original constant
# could not fit in the given float type # could not fit in the given float type
return str(inp.value) return str(inp.value)
if isinstance(inp, (ourlang.ConstantTuple, ourlang.ConstantStaticArray, )): if isinstance(inp, ourlang.ConstantTuple):
return '(' + ', '.join( return '(' + ', '.join(
expression(x) expression(x)
for x in inp.value for x in inp.value
) + ', )' ) + ', )'
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
return str(inp.name) return str(inp.variable.name)
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
if ( if (
inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS
or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS):
return f'{inp.operator}({expression(inp.right)})' return f'{inp.operator}({expression(inp.right)})'
if inp.operator == 'cast': if inp.operator == 'cast':
return f'{type_(inp.type)}({expression(inp.right)})' mtyp = type_var(inp.type_var)
if mtyp is None:
raise NotImplementedError(f'Casting to type {inp.type_var}')
return f'{mtyp}({expression(inp.right)})'
return f'{inp.operator}{expression(inp.right)}' return f'{inp.operator}{expression(inp.right)}'
@ -126,33 +87,29 @@ def expression(inp: ourlang.Expression) -> str:
for arg in inp.arguments for arg in inp.arguments
) )
if isinstance(inp.function, ourlang.StructConstructor): # TODO: Broken after new type system
return f'{inp.function.struct.name}({args})' # if isinstance(inp.function, ourlang.StructConstructor):
# return f'{inp.function.struct.name}({args})'
if isinstance(inp.function, ourlang.TupleConstructor): #
return f'({args}, )' # if isinstance(inp.function, ourlang.TupleConstructor):
# return f'({args}, )'
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.AccessBytesIndex): if isinstance(inp, ourlang.Subscript):
return f'{expression(inp.varref)}[{expression(inp.index)}]' varref = expression(inp.varref)
index = expression(inp.index)
if isinstance(inp, ourlang.AccessStructMember): return f'{varref}[{index}]'
return f'{expression(inp.varref)}.{inp.member.name}'
if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): # TODO: Broken after new type system
if isinstance(inp.member, ourlang.Expression): # if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}[{expression(inp.member)}]' # return f'{expression(inp.varref)}.{inp.member.name}'
return f'{expression(inp.varref)}[{inp.member.idx}]'
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
if isinstance(inp, ourlang.ModuleConstantReference):
return inp.definition.name
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:
@ -193,11 +150,11 @@ def function(inp: ourlang.Function) -> str:
result += '@imported\n' result += '@imported\n'
args = ', '.join( args = ', '.join(
f'{x}: {type_(y)}' f'{p.name}: {type_var(p.type_var)}'
for x, y in inp.posonlyargs for p in inp.posonlyargs
) )
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' result += f'def {inp.name}({args}) -> {type_var(inp.returns_type_var)}:\n'
if inp.imported: if inp.imported:
result += ' pass\n' result += ' pass\n'
@ -227,7 +184,7 @@ def module(inp: ourlang.Module) -> str:
for func in inp.functions.values(): for func in inp.functions.values():
if func.lineno < 0: if func.lineno < 0:
# Buildin (-2) or auto generated (-1) # Builtin (-2) or auto generated (-1)
continue continue
if result: if result:

View File

@ -1,6 +1,8 @@
""" """
This module contains the code to convert parsed Ourlang into WebAssembly code This module contains the code to convert parsed Ourlang into WebAssembly code
""" """
from typing import List, Optional
import struct import struct
from . import codestyle from . import codestyle
@ -12,19 +14,6 @@ from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
from .wasmgenerator import Generator as WasmGenerator from .wasmgenerator import Generator as WasmGenerator
LOAD_STORE_TYPE_MAP = {
typing.TypeUInt8: 'i32',
typing.TypeUInt32: 'i32',
typing.TypeUInt64: 'i64',
typing.TypeInt32: 'i32',
typing.TypeInt64: 'i64',
typing.TypeFloat32: 'f32',
typing.TypeFloat64: 'f64',
}
"""
When generating code, we sometimes need to load or store simple values
"""
def phasm_compile(inp: ourlang.Module) -> wasm.Module: def phasm_compile(inp: ourlang.Module) -> wasm.Module:
""" """
Public method for compiling a parsed Phasm module into Public method for compiling a parsed Phasm module into
@ -32,42 +21,57 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module:
""" """
return module(inp) return module(inp)
def type_(inp: typing.TypeBase) -> wasm.WasmType: def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
""" """
Compile: type Compile: type
"""
if isinstance(inp, typing.TypeNone):
return wasm.WasmTypeNone()
if isinstance(inp, typing.TypeUInt8): Types are used for example in WebAssembly function parameters
and return types.
"""
assert inp is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp)
if mtyp == 'u8':
# WebAssembly has only support for 32 and 64 bits # WebAssembly has only support for 32 and 64 bits
# So we need to store more memory per byte # So we need to store more memory per byte
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeUInt32): if mtyp == 'u32':
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeUInt64): if mtyp == 'u64':
return wasm.WasmTypeInt64() return wasm.WasmTypeInt64()
if isinstance(inp, typing.TypeInt32): if mtyp == 'i32':
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
if isinstance(inp, typing.TypeInt64): if mtyp == 'i64':
return wasm.WasmTypeInt64() return wasm.WasmTypeInt64()
if isinstance(inp, typing.TypeFloat32): if mtyp == 'f32':
return wasm.WasmTypeFloat32() return wasm.WasmTypeFloat32()
if isinstance(inp, typing.TypeFloat64): if mtyp == 'f64':
return wasm.WasmTypeFloat64() return wasm.WasmTypeFloat64()
if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): assert inp is not None, typing.ASSERTION_ERROR
# Structs and tuples are passed as pointer tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(type_var, inp)
if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# StaticArray, Tuples and Structs are passed as pointer
# And pointers are i32 # And pointers are i32
return wasm.WasmTypeInt32() return wasm.WasmTypeInt32()
raise NotImplementedError(type_, inp) # TODO: Broken after new type system
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
# # Structs and tuples are passed as pointer
# # And pointers are i32
# return wasm.WasmTypeInt32()
raise NotImplementedError(inp, mtyp)
# Operators that work for i32, i64, f32, f64 # Operators that work for i32, i64, f32, f64
OPERATOR_MAP = { OPERATOR_MAP = {
@ -81,8 +85,6 @@ U8_OPERATOR_MAP = {
# Under the hood, this is an i32 # Under the hood, this is an i32
# Implementing Right Shift XOR, OR, AND is fine since the 3 remaining # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining
# bytes stay zero after this operation # bytes stay zero after this operation
# Since it's unsigned an unsigned value, Logical or Arithmetic shift right
# are the same operation
'>>': 'shr_u', '>>': 'shr_u',
'^': 'xor', '^': 'xor',
'|': 'or', '|': 'or',
@ -131,109 +133,159 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
""" """
Compile: Any expression Compile: Any expression
""" """
if isinstance(inp, ourlang.ConstantUInt8): if isinstance(inp, ourlang.ConstantPrimitive):
assert inp.type_var is not None, typing.ASSERTION_ERROR
stp = typing.simplify(inp.type_var)
if stp is None:
raise NotImplementedError(f'Constants with type {inp.type_var}')
if stp == 'u8':
# No native u8 type - treat as i32, with caution
assert isinstance(inp.value, int)
wgn.i32.const(inp.value) wgn.i32.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantUInt32): if stp in ('i32', 'u32'):
assert isinstance(inp.value, int)
wgn.i32.const(inp.value) wgn.i32.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantUInt64): if stp in ('i64', 'u64'):
assert isinstance(inp.value, int)
wgn.i64.const(inp.value) wgn.i64.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantInt32): if stp == 'f32':
wgn.i32.const(inp.value) assert isinstance(inp.value, float)
return
if isinstance(inp, ourlang.ConstantInt64):
wgn.i64.const(inp.value)
return
if isinstance(inp, ourlang.ConstantFloat32):
wgn.f32.const(inp.value) wgn.f32.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantFloat64): if stp == 'f64':
assert isinstance(inp.value, float)
wgn.f64.const(inp.value) wgn.f64.const(inp.value)
return return
raise NotImplementedError(f'Constants with type {stp}')
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
wgn.add_statement('local.get', '${}'.format(inp.name)) if isinstance(inp.variable, ourlang.FunctionParam):
wgn.add_statement('local.get', '${}'.format(inp.variable.name))
return return
if isinstance(inp.variable, ourlang.ModuleConstantDef):
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(expression, inp, inp.variable.type_var)
# TODO: Broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
#
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
assert inp.variable.data_block is not None, 'Combined values are memory stored'
assert inp.variable.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.variable.data_block.address)
return
assert inp.variable.data_block is None, 'Primitives are not memory stored'
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.variable.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.variable.constant)
return
raise NotImplementedError(expression, inp.variable)
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
expression(wgn, inp.left) expression(wgn, inp.left)
expression(wgn, inp.right) expression(wgn, inp.right)
if isinstance(inp.type, typing.TypeUInt8): assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp == 'u8':
if operator := U8_OPERATOR_MAP.get(inp.operator, None): if operator := U8_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if isinstance(inp.type, typing.TypeUInt32): if mtyp == 'u32':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if operator := U32_OPERATOR_MAP.get(inp.operator, None): if operator := U32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if isinstance(inp.type, typing.TypeUInt64): if mtyp == 'u64':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if operator := U64_OPERATOR_MAP.get(inp.operator, None): if operator := U64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if isinstance(inp.type, typing.TypeInt32): if mtyp == 'i32':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if operator := I32_OPERATOR_MAP.get(inp.operator, None): if operator := I32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if isinstance(inp.type, typing.TypeInt64): if mtyp == 'i64':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if operator := I64_OPERATOR_MAP.get(inp.operator, None): if operator := I64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if isinstance(inp.type, typing.TypeFloat32): if mtyp == 'f32':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}') wgn.add_statement(f'f32.{operator}')
return return
if isinstance(inp.type, typing.TypeFloat64): if mtyp == 'f64':
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f64.{operator}') wgn.add_statement(f'f64.{operator}')
return return
raise NotImplementedError(expression, inp.type, inp.operator) raise NotImplementedError(expression, inp.type_var, inp.operator)
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
expression(wgn, inp.right) expression(wgn, inp.right)
if isinstance(inp.type, typing.TypeFloat32): assert inp.type_var is not None, typing.ASSERTION_ERROR
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: mtyp = typing.simplify(inp.type_var)
if mtyp == 'f32':
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f32.{inp.operator}') wgn.add_statement(f'f32.{inp.operator}')
return return
if isinstance(inp.type, typing.TypeFloat64): if mtyp == 'f64':
if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f64.{inp.operator}') wgn.add_statement(f'f64.{inp.operator}')
return return
if isinstance(inp.type, typing.TypeInt32): # TODO: Broken after new type system
if inp.operator == 'len': # if isinstance(inp.type, typing.TypeInt32):
if isinstance(inp.right.type, typing.TypeBytes): # if inp.operator == 'len':
wgn.i32.load() # if isinstance(inp.right.type, typing.TypeBytes):
return # wgn.i32.load()
# return
if inp.operator == 'cast': # if inp.operator == 'cast':
if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): # if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8):
# Nothing to do, you can use an u8 value as a u32 no problem # # Nothing to do, you can use an u8 value as a u32 no problem
return # return
raise NotImplementedError(expression, inp.type, inp.operator) raise NotImplementedError(expression, inp.type_var, inp.operator)
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
for arg in inp.arguments: for arg in inp.arguments:
@ -242,99 +294,119 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return
if isinstance(inp, ourlang.AccessBytesIndex): if isinstance(inp, ourlang.Subscript):
if not isinstance(inp.type, typing.TypeUInt8): assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
raise NotImplementedError(inp, inp.type) tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
expression(wgn, inp.varref) if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
expression(wgn, inp.index) if not isinstance(inp.index, ourlang.ConstantPrimitive):
wgn.call(stdlib_types.__subscript_bytes__) raise NotImplementedError(expression, inp, inp.index)
return if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value)
if isinstance(inp, ourlang.AccessStructMember): assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) mtyp = typing.simplify(inp.type_var)
if mtyp is None: if mtyp is None:
# In the future might extend this by having structs or tuples raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.member) if mtyp == 'u8':
# u8 operations are done using i32, since WASM does not have u8 operations
mtyp = 'i32'
elif mtyp == 'u32':
# u32 operations are done using i32, using _u operations
mtyp = 'i32'
elif mtyp == 'u64':
# u64 operations are done using i64, using _u operations
mtyp = 'i64'
tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
if tc_subs is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
assert 0 < len(tc_subs.members)
tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
if tc_bits is None or len(tc_bits.oneof) > 1:
raise NotImplementedError(expression, inp, inp.varref.type_var)
bitwidth = next(iter(tc_bits.oneof))
if bitwidth % 8 != 0:
raise NotImplementedError(expression, inp, inp.varref.type_var)
expression(wgn, inp.varref) expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
return return
if isinstance(inp, ourlang.AccessTupleMember): raise NotImplementedError(expression, inp, inp.varref.type_var)
mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.member)
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
return
if isinstance(inp, ourlang.AccessStaticArrayMember): # TODO: Broken after new type system
mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) # if isinstance(inp, ourlang.AccessBytesIndex):
if mtyp is None: # if not isinstance(inp.type, typing.TypeUInt8):
# In the future might extend this by having structs or tuples # raise NotImplementedError(inp, inp.type)
# as members of static arrays #
raise NotImplementedError(expression, inp, inp.member) # expression(wgn, inp.varref)
# expression(wgn, inp.index)
# wgn.call(stdlib_types.__subscript_bytes__)
# return
if isinstance(inp.member, typing.TypeStaticArrayMember): # if isinstance(inp, ourlang.AccessStructMember):
expression(wgn, inp.varref) # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) # if mtyp is None:
return # # In the future might extend this by having structs or tuples
# # as members of struct or tuples
expression(wgn, inp.varref) # raise NotImplementedError(expression, inp, inp.member)
expression(wgn, inp.member) #
wgn.i32.const(inp.static_array.member_type.alloc_size()) # expression(wgn, inp.varref)
wgn.i32.mul() # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
wgn.i32.add() # return
wgn.add_statement(f'{mtyp}.load') #
return # if isinstance(inp, ourlang.AccessTupleMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of struct or tuples
# raise NotImplementedError(expression, inp, inp.member)
#
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# if isinstance(inp, ourlang.AccessStaticArrayMember):
# mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__)
# if mtyp is None:
# # In the future might extend this by having structs or tuples
# # as members of static arrays
# raise NotImplementedError(expression, inp, inp.member)
#
# expression(wgn, inp.varref)
# expression(wgn, inp.member)
# wgn.i32.const(inp.static_array.member_type.alloc_size())
# wgn.i32.mul()
# wgn.i32.add()
# wgn.add_statement(f'{mtyp}.load')
# return
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp) expression_fold(wgn, inp)
return return
if isinstance(inp, ourlang.ModuleConstantReference):
if isinstance(inp.type, typing.TypeTuple):
assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
assert inp.definition.data_block is not None, 'Combined values are memory stored'
assert inp.definition.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.definition.data_block.address)
return
if isinstance(inp.type, typing.TypeStaticArray):
assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
assert inp.definition.data_block is not None, 'Combined values are memory stored'
assert inp.definition.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.definition.data_block.address)
return
assert inp.definition.data_block is None, 'Primitives are not memory stored'
mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type)
expression(wgn, inp.definition.constant)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
""" """
Compile: Fold expression Compile: Fold expression
""" """
mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__) assert inp.base.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.base.type_var)
if mtyp is None: if mtyp is None:
# In the future might extend this by having structs or tuples # In the future might extend this by having structs or tuples
# as members of struct or tuples # as members of struct or tuples
raise NotImplementedError(expression, inp, inp.base) raise NotImplementedError(expression, inp, inp.base)
raise NotImplementedError('TODO: Broken after new type system')
if inp.iter.type.__class__.__name__ != 'TypeBytes': if inp.iter.type.__class__.__name__ != 'TypeBytes':
raise NotImplementedError(expression, inp, inp.iter.type) raise NotImplementedError(expression, inp, inp.iter.type)
@ -450,7 +522,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
""" """
Compile: function argument Compile: function argument
""" """
return (inp[0], type_(inp[1]), ) return (inp.name, type_var(inp.type_var), )
def import_(inp: ourlang.Function) -> wasm.Import: def import_(inp: ourlang.Function) -> wasm.Import:
""" """
@ -466,7 +538,7 @@ def import_(inp: ourlang.Function) -> wasm.Import:
function_argument(x) function_argument(x)
for x in inp.posonlyargs for x in inp.posonlyargs
], ],
type_(inp.returns) type_var(inp.returns_type_var)
) )
def function(inp: ourlang.Function) -> wasm.Function: def function(inp: ourlang.Function) -> wasm.Function:
@ -477,10 +549,10 @@ def function(inp: ourlang.Function) -> wasm.Function:
wgn = WasmGenerator() wgn = WasmGenerator()
if isinstance(inp, ourlang.TupleConstructor): if False: # TODO: isinstance(inp, ourlang.TupleConstructor):
_generate_tuple_constructor(wgn, inp) pass # _generate_tuple_constructor(wgn, inp)
elif isinstance(inp, ourlang.StructConstructor): elif False: # TODO: isinstance(inp, ourlang.StructConstructor):
_generate_struct_constructor(wgn, inp) pass # _generate_struct_constructor(wgn, inp)
else: else:
for stat in inp.statements: for stat in inp.statements:
statement(wgn, stat) statement(wgn, stat)
@ -496,7 +568,7 @@ def function(inp: ourlang.Function) -> wasm.Function:
(k, v.wasm_type(), ) (k, v.wasm_type(), )
for k, v in wgn.locals.items() for k, v in wgn.locals.items()
], ],
type_(inp.returns), type_var(inp.returns_type_var),
wgn.statements wgn.statements
) )
@ -555,38 +627,48 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
for block in inp.blocks: for block in inp.blocks:
block.address = unalloc_ptr + 4 # 4 bytes for allocator header block.address = unalloc_ptr + 4 # 4 bytes for allocator header
data_list = [] data_list: List[bytes] = []
for constant in block.data: for constant in block.data:
if isinstance(constant, ourlang.ConstantUInt8): assert constant.type_var is not None
mtyp = typing.simplify(constant.type_var)
if mtyp == 'u8':
assert isinstance(constant.value, int)
data_list.append(module_data_u8(constant.value)) data_list.append(module_data_u8(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantUInt32): if mtyp == 'u32':
assert isinstance(constant.value, int)
data_list.append(module_data_u32(constant.value)) data_list.append(module_data_u32(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantUInt64): if mtyp == 'u64':
assert isinstance(constant.value, int)
data_list.append(module_data_u64(constant.value)) data_list.append(module_data_u64(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantInt32): if mtyp == 'i32':
assert isinstance(constant.value, int)
data_list.append(module_data_i32(constant.value)) data_list.append(module_data_i32(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantInt64): if mtyp == 'i64':
assert isinstance(constant.value, int)
data_list.append(module_data_i64(constant.value)) data_list.append(module_data_i64(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantFloat32): if mtyp == 'f32':
assert isinstance(constant.value, float)
data_list.append(module_data_f32(constant.value)) data_list.append(module_data_f32(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantFloat64): if mtyp == 'f64':
assert isinstance(constant.value, float)
data_list.append(module_data_f64(constant.value)) data_list.append(module_data_f64(constant.value))
continue continue
raise NotImplementedError(constant) raise NotImplementedError(constant, constant.type_var, mtyp)
block_data = b''.join(data_list) block_data = b''.join(data_list)
@ -636,48 +718,49 @@ def module(inp: ourlang.Module) -> wasm.Module:
return result return result
def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: # TODO: Broken after new type system
tmp_var = wgn.temp_var_i32('tuple_adr') # def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None:
# tmp_var = wgn.temp_var_i32('tuple_adr')
# Allocated the required amounts of bytes in memory #
wgn.i32.const(inp.tuple.alloc_size()) # # Allocated the required amounts of bytes in memory
wgn.call(stdlib_alloc.__alloc__) # wgn.i32.const(inp.tuple.alloc_size())
wgn.local.set(tmp_var) # wgn.call(stdlib_alloc.__alloc__)
# wgn.local.set(tmp_var)
# Store each member individually #
for member in inp.tuple.members: # # Store each member individually
mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) # for member in inp.tuple.members:
if mtyp is None: # mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
# In the future might extend this by having structs or tuples # if mtyp is None:
# as members of struct or tuples # # In the future might extend this by having structs or tuples
raise NotImplementedError(expression, inp, member) # # as members of struct or tuples
# raise NotImplementedError(expression, inp, member)
wgn.local.get(tmp_var) #
wgn.add_statement('local.get', f'$arg{member.idx}') # wgn.local.get(tmp_var)
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) # wgn.add_statement('local.get', f'$arg{member.idx}')
# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
# Return the allocated address #
wgn.local.get(tmp_var) # # Return the allocated address
# wgn.local.get(tmp_var)
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: #
tmp_var = wgn.temp_var_i32('struct_adr') # def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:
# tmp_var = wgn.temp_var_i32('struct_adr')
# Allocated the required amounts of bytes in memory #
wgn.i32.const(inp.struct.alloc_size()) # # Allocated the required amounts of bytes in memory
wgn.call(stdlib_alloc.__alloc__) # wgn.i32.const(inp.struct.alloc_size())
wgn.local.set(tmp_var) # wgn.call(stdlib_alloc.__alloc__)
# wgn.local.set(tmp_var)
# Store each member individually #
for member in inp.struct.members: # # Store each member individually
mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) # for member in inp.struct.members:
if mtyp is None: # mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
# In the future might extend this by having structs or tuples # if mtyp is None:
# as members of struct or tuples # # In the future might extend this by having structs or tuples
raise NotImplementedError(expression, inp, member) # # as members of struct or tuples
# raise NotImplementedError(expression, inp, member)
wgn.local.get(tmp_var) #
wgn.add_statement('local.get', f'${member.name}') # wgn.local.get(tmp_var)
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) # wgn.add_statement('local.get', f'${member.name}')
# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset))
# Return the allocated address #
wgn.local.get(tmp_var) # # Return the allocated address
# wgn.local.get(tmp_var)

View File

@ -6,3 +6,8 @@ class StaticError(Exception):
""" """
An error found during static analysis An error found during static analysis
""" """
class TypingError(Exception):
"""
An error found during the typing phase
"""

View File

@ -1,38 +1,31 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
from typing import Dict, List, Tuple, Optional, Union from typing import Dict, List, Optional, Union
import enum import enum
from typing_extensions import Final from typing_extensions import Final
WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', )
WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', )
from .typing import ( from .typing import (
TypeBase, TypeStruct,
TypeNone,
TypeBool, TypeVar,
TypeUInt8, TypeUInt32, TypeUInt64,
TypeInt32, TypeInt64,
TypeFloat32, TypeFloat64,
TypeBytes,
TypeTuple, TypeTupleMember,
TypeStaticArray, TypeStaticArrayMember,
TypeStruct, TypeStructMember,
) )
class Expression: class Expression:
""" """
An expression within a statement An expression within a statement
""" """
__slots__ = ('type', ) __slots__ = ('type_var', )
type: TypeBase type_var: Optional[TypeVar]
def __init__(self, type_: TypeBase) -> None: def __init__(self) -> None:
self.type = type_ self.type_var = None
class Constant(Expression): class Constant(Expression):
""" """
@ -40,88 +33,16 @@ class Constant(Expression):
""" """
__slots__ = () __slots__ = ()
class ConstantUInt8(Constant): class ConstantPrimitive(Constant):
""" """
An UInt8 constant value expression within a statement An primitive constant value expression within a statement
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: int value: Union[int, float]
def __init__(self, type_: TypeUInt8, value: int) -> None: def __init__(self, value: Union[int, float]) -> None:
super().__init__(type_) super().__init__()
self.value = value
class ConstantUInt32(Constant):
"""
An UInt32 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeUInt32, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantUInt64(Constant):
"""
An UInt64 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeUInt64, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantInt32(Constant):
"""
An Int32 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeInt32, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantInt64(Constant):
"""
An Int64 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeInt64, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantFloat32(Constant):
"""
An Float32 constant value expression within a statement
"""
__slots__ = ('value', )
value: float
def __init__(self, type_: TypeFloat32, value: float) -> None:
super().__init__(type_)
self.value = value
class ConstantFloat64(Constant):
"""
An Float64 constant value expression within a statement
"""
__slots__ = ('value', )
value: float
def __init__(self, type_: TypeFloat64, value: float) -> None:
super().__init__(type_)
self.value = value self.value = value
class ConstantTuple(Constant): class ConstantTuple(Constant):
@ -130,35 +51,23 @@ class ConstantTuple(Constant):
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: List[Constant] value: List[ConstantPrimitive]
def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
super().__init__(type_) super().__init__()
self.value = value
class ConstantStaticArray(Constant):
"""
A StaticArray constant value expression within a statement
"""
__slots__ = ('value', )
value: List[Constant]
def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None:
super().__init__(type_)
self.value = value self.value = value
class VariableReference(Expression): class VariableReference(Expression):
""" """
An variable reference expression within a statement An variable reference expression within a statement
""" """
__slots__ = ('name', ) __slots__ = ('variable', )
name: str variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local
def __init__(self, type_: TypeBase, name: str) -> None: def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None:
super().__init__(type_) super().__init__()
self.name = name self.variable = variable
class UnaryOp(Expression): class UnaryOp(Expression):
""" """
@ -169,8 +78,8 @@ class UnaryOp(Expression):
operator: str operator: str
right: Expression right: Expression
def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: def __init__(self, operator: str, right: Expression) -> None:
super().__init__(type_) super().__init__()
self.operator = operator self.operator = operator
self.right = right self.right = right
@ -185,8 +94,8 @@ class BinaryOp(Expression):
left: Expression left: Expression
right: Expression right: Expression
def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None: def __init__(self, operator: str, left: Expression, right: Expression) -> None:
super().__init__(type_) super().__init__()
self.operator = operator self.operator = operator
self.left = left self.left = left
@ -202,73 +111,27 @@ class FunctionCall(Expression):
arguments: List[Expression] arguments: List[Expression]
def __init__(self, function: 'Function') -> None: def __init__(self, function: 'Function') -> None:
super().__init__(function.returns) super().__init__()
self.function = function self.function = function
self.arguments = [] self.arguments = []
class AccessBytesIndex(Expression): class Subscript(Expression):
""" """
Access a bytes index for reading A subscript, for example to refer to a static array or tuple
by index
""" """
__slots__ = ('varref', 'index', ) __slots__ = ('varref', 'index', )
varref: VariableReference varref: VariableReference
index: Expression index: Expression
def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None: def __init__(self, varref: VariableReference, index: Expression) -> None:
super().__init__(type_) super().__init__()
self.varref = varref self.varref = varref
self.index = index self.index = index
class AccessStructMember(Expression):
"""
Access a struct member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: TypeStructMember
def __init__(self, varref: VariableReference, member: TypeStructMember) -> None:
super().__init__(member.type)
self.varref = varref
self.member = member
class AccessTupleMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: TypeTupleMember
def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None:
super().__init__(member.type)
self.varref = varref
self.member = member
class AccessStaticArrayMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'static_array', 'member', )
varref: Union['ModuleConstantReference', VariableReference]
static_array: TypeStaticArray
member: Union[Expression, TypeStaticArrayMember]
def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None:
super().__init__(static_array.member_type)
self.varref = varref
self.static_array = static_array
self.member = member
class Fold(Expression): class Fold(Expression):
""" """
A (left or right) fold A (left or right) fold
@ -287,31 +150,18 @@ class Fold(Expression):
def __init__( def __init__(
self, self,
type_: TypeBase,
dir_: Direction, dir_: Direction,
func: 'Function', func: 'Function',
base: Expression, base: Expression,
iter_: Expression, iter_: Expression,
) -> None: ) -> None:
super().__init__(type_) super().__init__()
self.dir = dir_ self.dir = dir_
self.func = func self.func = func
self.base = base self.base = base
self.iter = iter_ self.iter = iter_
class ModuleConstantReference(Expression):
"""
An reference to a module constant expression within a statement
"""
__slots__ = ('definition', )
definition: 'ModuleConstantDef'
def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None:
super().__init__(type_)
self.definition = definition
class Statement: class Statement:
""" """
A statement within a function A statement within a function
@ -348,20 +198,34 @@ class StatementIf(Statement):
self.statements = [] self.statements = []
self.else_statements = [] self.else_statements = []
FunctionParam = Tuple[str, TypeBase] class FunctionParam:
"""
A parameter for a Function
"""
__slots__ = ('name', 'type_str', 'type_var', )
name: str
type_str: str
type_var: Optional[TypeVar]
def __init__(self, name: str, type_str: str) -> None:
self.name = name
self.type_str = type_str
self.type_var = None
class Function: class Function:
""" """
A function processes input and produces output A function processes input and produces output
""" """
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns_str', 'returns_type_var', 'posonlyargs', )
name: str name: str
lineno: int lineno: int
exported: bool exported: bool
imported: bool imported: bool
statements: List[Statement] statements: List[Statement]
returns: TypeBase returns_str: str
returns_type_var: Optional[TypeVar]
posonlyargs: List[FunctionParam] posonlyargs: List[FunctionParam]
def __init__(self, name: str, lineno: int) -> None: def __init__(self, name: str, lineno: int) -> None:
@ -370,66 +234,70 @@ class Function:
self.exported = False self.exported = False
self.imported = False self.imported = False
self.statements = [] self.statements = []
self.returns = TypeNone() self.returns_str = 'None'
self.returns_type_var = None
self.posonlyargs = [] self.posonlyargs = []
class StructConstructor(Function): # TODO: Broken after new type system
""" # class StructConstructor(Function):
The constructor method for a struct # """
# The constructor method for a struct
A function will generated to instantiate a struct. The arguments #
will be the defaults # A function will generated to instantiate a struct. The arguments
""" # will be the defaults
__slots__ = ('struct', ) # """
# __slots__ = ('struct', )
struct: TypeStruct #
# struct: TypeStruct
def __init__(self, struct: TypeStruct) -> None: #
super().__init__(f'@{struct.name}@__init___@', -1) # def __init__(self, struct: TypeStruct) -> None:
# super().__init__(f'@{struct.name}@__init___@', -1)
self.returns = struct #
# self.returns = struct
for mem in struct.members: #
self.posonlyargs.append((mem.name, mem.type, )) # for mem in struct.members:
# self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
self.struct = struct #
# self.struct = struct
class TupleConstructor(Function): #
""" # class TupleConstructor(Function):
The constructor method for a tuple # """
""" # The constructor method for a tuple
__slots__ = ('tuple', ) # """
# __slots__ = ('tuple', )
tuple: TypeTuple #
# tuple: TypeTuple
def __init__(self, tuple_: TypeTuple) -> None: #
name = tuple_.render_internal_name() # def __init__(self, tuple_: TypeTuple) -> None:
# name = tuple_.render_internal_name()
super().__init__(f'@{name}@__init___@', -1) #
# super().__init__(f'@{name}@__init___@', -1)
self.returns = tuple_ #
# self.returns = tuple_
for mem in tuple_.members: #
self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) # for mem in tuple_.members:
# self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
self.tuple = tuple_ #
# self.tuple = tuple_
class ModuleConstantDef: class ModuleConstantDef:
""" """
A constant definition within a module A constant definition within a module
""" """
__slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', ) __slots__ = ('name', 'lineno', 'type_str', 'type_var', 'constant', 'data_block', )
name: str name: str
lineno: int lineno: int
type: TypeBase type_str: str
type_var: Optional[TypeVar]
constant: Constant constant: Constant
data_block: Optional['ModuleDataBlock'] data_block: Optional['ModuleDataBlock']
def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: def __init__(self, name: str, lineno: int, type_str: str, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None:
self.name = name self.name = name
self.lineno = lineno self.lineno = lineno
self.type = type_ self.type_str = type_str
self.type_var = None
self.constant = constant self.constant = constant
self.data_block = data_block self.data_block = data_block
@ -439,10 +307,10 @@ class ModuleDataBlock:
""" """
__slots__ = ('data', 'address', ) __slots__ = ('data', 'address', )
data: List[Constant] data: List[ConstantPrimitive]
address: Optional[int] address: Optional[int]
def __init__(self, data: List[Constant]) -> None: def __init__(self, data: List[ConstantPrimitive]) -> None:
self.data = data self.data = data
self.address = None self.address = None
@ -464,23 +332,11 @@ class Module:
__slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',) __slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',)
data: ModuleData data: ModuleData
types: Dict[str, TypeBase]
structs: Dict[str, TypeStruct] structs: Dict[str, TypeStruct]
constant_defs: Dict[str, ModuleConstantDef] constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function] functions: Dict[str, Function]
def __init__(self) -> None: def __init__(self) -> None:
self.types = {
'None': TypeNone(),
'u8': TypeUInt8(),
'u32': TypeUInt32(),
'u64': TypeUInt64(),
'i32': TypeInt32(),
'i64': TypeInt64(),
'f32': TypeFloat32(),
'f64': TypeFloat64(),
'bytes': TypeBytes(),
}
self.data = ModuleData() self.data = ModuleData()
self.structs = {} self.structs = {}
self.constant_defs = {} self.constant_defs = {}

View File

@ -6,48 +6,33 @@ from typing import Any, Dict, NoReturn, Union
import ast import ast
from .typing import ( from .typing import (
TypeBase, BUILTIN_TYPES,
TypeUInt8,
TypeUInt32,
TypeUInt64,
TypeInt32,
TypeInt64,
TypeFloat32,
TypeFloat64,
TypeBytes,
TypeStruct, TypeStruct,
TypeStructMember, TypeStructMember,
TypeTuple,
TypeTupleMember,
TypeStaticArray,
TypeStaticArrayMember,
) )
from . import codestyle
from .exceptions import StaticError from .exceptions import StaticError
from .ourlang import ( from .ourlang import (
WEBASSEMBLY_BUILDIN_FLOAT_OPS, WEBASSEMBLY_BUILTIN_FLOAT_OPS,
Module, ModuleDataBlock, Module, ModuleDataBlock,
Function, Function,
Expression, Expression,
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp, BinaryOp,
Constant, ConstantPrimitive, ConstantTuple,
ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64,
ConstantUInt8, ConstantUInt32, ConstantUInt64,
ConstantTuple, ConstantStaticArray,
FunctionCall, FunctionCall, Subscript,
StructConstructor, TupleConstructor, # StructConstructor, TupleConstructor,
UnaryOp, VariableReference, UnaryOp, VariableReference,
Fold, ModuleConstantReference, Fold,
Statement, Statement,
StatementIf, StatementPass, StatementReturn, StatementIf, StatementPass, StatementReturn,
FunctionParam,
ModuleConstantDef, ModuleConstantDef,
) )
@ -60,7 +45,7 @@ def phasm_parse(source: str) -> Module:
our_visitor = OurVisitor() our_visitor = OurVisitor()
return our_visitor.visit_Module(res) return our_visitor.visit_Module(res)
OurLocals = Dict[str, TypeBase] OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants?
class OurVisitor: class OurVisitor:
""" """
@ -95,15 +80,16 @@ class OurVisitor:
module.constant_defs[res.name] = res module.constant_defs[res.name] = res
if isinstance(res, TypeStruct): # TODO: Broken after type system
if res.name in module.structs: # if isinstance(res, TypeStruct):
raise StaticError( # if res.name in module.structs:
f'{res.name} already defined on line {module.structs[res.name].lineno}' # raise StaticError(
) # f'{res.name} already defined on line {module.structs[res.name].lineno}'
# )
module.structs[res.name] = res #
constructor = StructConstructor(res) # module.structs[res.name] = res
module.functions[constructor.name] = constructor # constructor = StructConstructor(res)
# module.functions[constructor.name] = constructor
if isinstance(res, Function): if isinstance(res, Function):
if res.name in module.functions: if res.name in module.functions:
@ -141,7 +127,7 @@ class OurVisitor:
if not arg.annotation: if not arg.annotation:
_raise_static_error(node, 'Type is required') _raise_static_error(node, 'Type is required')
function.posonlyargs.append(( function.posonlyargs.append(FunctionParam(
arg.arg, arg.arg,
self.visit_type(module, arg.annotation), self.visit_type(module, arg.annotation),
)) ))
@ -167,7 +153,7 @@ class OurVisitor:
function.imported = True function.imported = True
if node.returns: if node.returns:
function.returns = self.visit_type(module, node.returns) function.returns_str = self.visit_type(module, node.returns)
_not_implemented(not node.type_comment, 'FunctionDef.type_comment') _not_implemented(not node.type_comment, 'FunctionDef.type_comment')
@ -195,6 +181,7 @@ class OurVisitor:
if stmt.simple != 1: if stmt.simple != 1:
raise NotImplementedError('Class with non-simple arguments') raise NotImplementedError('Class with non-simple arguments')
raise NotImplementedError('TODO: Broken after new type system')
member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset)
struct.members.append(member) struct.members.append(member)
@ -208,34 +195,22 @@ class OurVisitor:
if not isinstance(node.target.ctx, ast.Store): if not isinstance(node.target.ctx, ast.Store):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
exp_type = self.visit_type(module, node.annotation) if isinstance(node.value, ast.Constant):
return ModuleConstantDef(
if isinstance(exp_type, TypeInt32):
if not isinstance(node.value, ast.Constant):
_raise_static_error(node, 'Must be constant')
constant = ModuleConstantDef(
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type, self.visit_type(module, node.annotation),
self.visit_Module_Constant(module, exp_type, node.value), self.visit_Module_Constant(module, node.value),
None, None,
) )
return constant
if isinstance(exp_type, TypeTuple):
if not isinstance(node.value, ast.Tuple):
_raise_static_error(node, 'Must be tuple')
if len(exp_type.members) != len(node.value.elts):
_raise_static_error(node, 'Invalid number of tuple values')
if isinstance(node.value, ast.Tuple):
tuple_data = [ tuple_data = [
self.visit_Module_Constant(module, mem.type, arg_node) self.visit_Module_Constant(module, arg_node)
for arg_node, mem in zip(node.value.elts, exp_type.members) for arg_node in node.value.elts
if isinstance(arg_node, ast.Constant) if isinstance(arg_node, ast.Constant)
] ]
if len(exp_type.members) != len(tuple_data): if len(node.value.elts) != len(tuple_data):
_raise_static_error(node, 'Tuple arguments must be constants') _raise_static_error(node, 'Tuple arguments must be constants')
# Allocate the data # Allocate the data
@ -246,40 +221,69 @@ class OurVisitor:
return ModuleConstantDef( return ModuleConstantDef(
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type, self.visit_type(module, node.annotation),
ConstantTuple(exp_type, tuple_data), ConstantTuple(tuple_data),
data_block, data_block,
) )
if isinstance(exp_type, TypeStaticArray): raise NotImplementedError('TODO: Broken after new typing system')
if not isinstance(node.value, ast.Tuple):
_raise_static_error(node, 'Must be static array')
if len(exp_type.members) != len(node.value.elts): # if isinstance(exp_type, TypeTuple):
_raise_static_error(node, 'Invalid number of static array values') # if not isinstance(node.value, ast.Tuple):
# _raise_static_error(node, 'Must be tuple')
static_array_data = [ #
self.visit_Module_Constant(module, exp_type.member_type, arg_node) # if len(exp_type.members) != len(node.value.elts):
for arg_node in node.value.elts # _raise_static_error(node, 'Invalid number of tuple values')
if isinstance(arg_node, ast.Constant) #
] # tuple_data = [
if len(exp_type.members) != len(static_array_data): # self.visit_Module_Constant(module, arg_node)
_raise_static_error(node, 'Static array arguments must be constants') # for arg_node, mem in zip(node.value.elts, exp_type.members)
# if isinstance(arg_node, ast.Constant)
# Allocate the data # ]
data_block = ModuleDataBlock(static_array_data) # if len(exp_type.members) != len(tuple_data):
module.data.blocks.append(data_block) # _raise_static_error(node, 'Tuple arguments must be constants')
#
# Then return the constant as a pointer # # Allocate the data
return ModuleConstantDef( # data_block = ModuleDataBlock(tuple_data)
node.target.id, # module.data.blocks.append(data_block)
node.lineno, #
exp_type, # # Then return the constant as a pointer
ConstantStaticArray(exp_type, static_array_data), # return ModuleConstantDef(
data_block, # node.target.id,
) # node.lineno,
# exp_type,
raise NotImplementedError(f'{node} on Module AnnAssign') # ConstantTuple(tuple_data),
# data_block,
# )
#
# if isinstance(exp_type, TypeStaticArray):
# if not isinstance(node.value, ast.Tuple):
# _raise_static_error(node, 'Must be static array')
#
# if len(exp_type.members) != len(node.value.elts):
# _raise_static_error(node, 'Invalid number of static array values')
#
# static_array_data = [
# self.visit_Module_Constant(module, arg_node)
# for arg_node in node.value.elts
# if isinstance(arg_node, ast.Constant)
# ]
# if len(exp_type.members) != len(static_array_data):
# _raise_static_error(node, 'Static array arguments must be constants')
#
# # Allocate the data
# data_block = ModuleDataBlock(static_array_data)
# module.data.blocks.append(data_block)
#
# # Then return the constant as a pointer
# return ModuleConstantDef(
# node.target.id,
# node.lineno,
# ConstantStaticArray(static_array_data),
# data_block,
# )
#
# raise NotImplementedError(f'{node} on Module AnnAssign')
def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
@ -297,7 +301,10 @@ class OurVisitor:
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
function = module.functions[node.name] function = module.functions[node.name]
our_locals = dict(function.posonlyargs) our_locals: OurLocals = {
x.name: x
for x in function.posonlyargs
}
for stmt in node.body: for stmt in node.body:
function.statements.append( function.statements.append(
@ -311,12 +318,12 @@ class OurVisitor:
_raise_static_error(node, 'Return must have an argument') _raise_static_error(node, 'Return must have an argument')
return StatementReturn( return StatementReturn(
self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
) )
if isinstance(node, ast.If): if isinstance(node, ast.If):
result = StatementIf( result = StatementIf(
self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test)
) )
for stmt in node.body: for stmt in node.body:
@ -336,7 +343,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as stmt in FunctionDef') raise NotImplementedError(f'{node} as stmt in FunctionDef')
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression:
if isinstance(node, ast.BinOp): if isinstance(node, ast.BinOp):
if isinstance(node.op, ast.Add): if isinstance(node.op, ast.Add):
operator = '+' operator = '+'
@ -361,10 +368,9 @@ class OurVisitor:
# e.g. you can do `"hello" * 3` with the code below (yet) # e.g. you can do `"hello" * 3` with the code below (yet)
return BinaryOp( return BinaryOp(
exp_type,
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
) )
if isinstance(node, ast.UnaryOp): if isinstance(node, ast.UnaryOp):
@ -376,9 +382,8 @@ class OurVisitor:
raise NotImplementedError(f'Operator {node.op}') raise NotImplementedError(f'Operator {node.op}')
return UnaryOp( return UnaryOp(
exp_type,
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.operand),
) )
if isinstance(node, ast.Compare): if isinstance(node, ast.Compare):
@ -398,28 +403,27 @@ class OurVisitor:
# e.g. you can do `"hello" * 3` with the code below (yet) # e.g. you can do `"hello" * 3` with the code below (yet)
return BinaryOp( return BinaryOp(
exp_type,
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
) )
if isinstance(node, ast.Call): if isinstance(node, ast.Call):
return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) return self.visit_Module_FunctionDef_Call(module, function, our_locals, node)
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return self.visit_Module_Constant( return self.visit_Module_Constant(
module, exp_type, node, module, node,
) )
if isinstance(node, ast.Attribute): if isinstance(node, ast.Attribute):
return self.visit_Module_FunctionDef_Attribute( return self.visit_Module_FunctionDef_Attribute(
module, function, our_locals, exp_type, node, module, function, our_locals, node,
) )
if isinstance(node, ast.Subscript): if isinstance(node, ast.Subscript):
return self.visit_Module_FunctionDef_Subscript( return self.visit_Module_FunctionDef_Subscript(
module, function, our_locals, exp_type, node, module, function, our_locals, node,
) )
if isinstance(node, ast.Name): if isinstance(node, ast.Name):
@ -427,45 +431,41 @@ class OurVisitor:
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.id in our_locals: if node.id in our_locals:
act_type = our_locals[node.id] param = our_locals[node.id]
if exp_type != act_type: return VariableReference(param)
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}')
return VariableReference(act_type, node.id)
if node.id in module.constant_defs: if node.id in module.constant_defs:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
if exp_type != cdef.type: return VariableReference(cdef)
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}')
return ModuleConstantReference(exp_type, cdef)
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load): raise NotImplementedError('TODO: Broken after new type system')
_raise_static_error(node, 'Must be load context')
if isinstance(exp_type, TypeTuple): # if not isinstance(node.ctx, ast.Load):
if len(exp_type.members) != len(node.elts): # _raise_static_error(node, 'Must be load context')
_raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') #
# if isinstance(exp_type, TypeTuple):
tuple_constructor = TupleConstructor(exp_type) # if len(exp_type.members) != len(node.elts):
# _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given')
func = module.functions[tuple_constructor.name] #
# tuple_constructor = TupleConstructor(exp_type)
result = FunctionCall(func) #
result.arguments = [ # func = module.functions[tuple_constructor.name]
self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) #
for arg_node, mem in zip(node.elts, exp_type.members) # result = FunctionCall(func)
] # result.arguments = [
return result # self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node)
# for arg_node, mem in zip(node.elts, exp_type.members)
_raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') # ]
# return result
#
# _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple')
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -474,48 +474,37 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.func.id in module.structs: # if node.func.id in module.structs:
struct = module.structs[node.func.id] # raise NotImplementedError('TODO: Broken after new type system')
struct_constructor = StructConstructor(struct) # struct = module.structs[node.func.id]
# struct_constructor = StructConstructor(struct)
func = module.functions[struct_constructor.name] #
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: # func = module.functions[struct_constructor.name]
if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): if node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS:
_raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}')
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp( return UnaryOp(
exp_type,
'sqrt', 'sqrt',
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
) )
elif node.func.id == 'u32': elif node.func.id == 'u32':
if not isinstance(exp_type, TypeUInt32):
_raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
# FIXME: This is a stub, proper casting is todo # FIXME: This is a stub, proper casting is todo
return UnaryOp( return UnaryOp(
exp_type,
'cast', 'cast',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
) )
elif node.func.id == 'len': elif node.func.id == 'len':
if not isinstance(exp_type, TypeInt32):
_raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp( return UnaryOp(
exp_type,
'len', 'len',
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
) )
elif node.func.id == 'foldl': elif node.func.id == 'foldl':
# TODO: This should a much more generic function! # TODO: This should a much more generic function!
@ -538,21 +527,13 @@ class OurVisitor:
if 2 != len(func.posonlyargs): if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
if exp_type.__class__ != func.returns.__class__: raise NotImplementedError('TODO: Broken after new type system')
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if func.returns.__class__ != func.posonlyargs[0][1].__class__:
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}')
if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold( return Fold(
exp_type,
Fold.Direction.LEFT, Fold.Direction.LEFT,
func, func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
) )
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
@ -560,49 +541,46 @@ class OurVisitor:
func = module.functions[node.func.id] func = module.functions[node.func.id]
if func.returns != exp_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if len(func.posonlyargs) != len(node.args): if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
result = FunctionCall(func) result = FunctionCall(func)
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) for arg_expr, param in zip(node.args, func.posonlyargs)
) )
return result return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
del module raise NotImplementedError('Broken after new type system')
del function # del module
# del function
#
# if not isinstance(node.value, ast.Name):
# _raise_static_error(node, 'Must reference a name')
#
# if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
#
# if not node.value.id in our_locals:
# _raise_static_error(node, f'Undefined variable {node.value.id}')
#
# param = our_locals[node.value.id]
#
# node_typ = param.type
# if not isinstance(node_typ, TypeStruct):
# _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
#
# member = node_typ.get_member(node.attr)
# if member is None:
# _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
#
# return AccessStructMember(
# VariableReference(param),
# member,
# )
if not isinstance(node.value, ast.Name): def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}')
node_typ = our_locals[node.value.id]
if not isinstance(node_typ, TypeStruct):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
member = node_typ.get_member(node.attr)
if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
if exp_type != member.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
return AccessStructMember(
VariableReference(node_typ, node.value.id),
member,
)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name): if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name') _raise_static_error(node, 'Must reference a name')
@ -612,154 +590,93 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
varref: Union[ModuleConstantReference, VariableReference] varref: VariableReference
if node.value.id in our_locals: if node.value.id in our_locals:
node_typ = our_locals[node.value.id] param = our_locals[node.value.id]
varref = VariableReference(node_typ, node.value.id) varref = VariableReference(param)
elif node.value.id in module.constant_defs: elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id] constant_def = module.constant_defs[node.value.id]
node_typ = constant_def.type varref = VariableReference(constant_def)
varref = ModuleConstantReference(node_typ, constant_def)
else: else:
_raise_static_error(node, f'Undefined variable {node.value.id}') _raise_static_error(node, f'Undefined variable {node.value.id}')
slice_expr = self.visit_Module_FunctionDef_expr( slice_expr = self.visit_Module_FunctionDef_expr(
module, function, our_locals, module.types['u32'], node.slice.value, module, function, our_locals, node.slice.value,
) )
if isinstance(node_typ, TypeBytes): return Subscript(varref, slice_expr)
t_u8 = module.types['u8']
if exp_type != t_u8:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}')
if isinstance(varref, ModuleConstantReference): # if isinstance(node_typ, TypeBytes):
raise NotImplementedError(f'{node} from module constant') # if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')
#
# return AccessBytesIndex(
# varref,
# slice_expr,
# )
#
# if isinstance(node_typ, TypeTuple):
# if not isinstance(slice_expr, ConstantPrimitive):
# _raise_static_error(node, 'Must subscript using a constant index')
#
# idx = slice_expr.value
#
# if not isinstance(idx, int):
# _raise_static_error(node, 'Must subscript using a constant integer index')
#
# if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
#
# tuple_member = node_typ.members[idx]
#
# if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant')
#
# return AccessTupleMember(
# varref,
# tuple_member,
# )
#
# if isinstance(node_typ, TypeStaticArray):
# if not isinstance(slice_expr, ConstantPrimitive):
# return AccessStaticArrayMember(
# varref,
# node_typ,
# slice_expr,
# )
#
# idx = slice_expr.value
#
# if not isinstance(idx, int):
# _raise_static_error(node, 'Must subscript using an integer index')
#
# if not (0 <= idx < len(node_typ.members)):
# _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
#
# static_array_member = node_typ.members[idx]
#
# return AccessStaticArrayMember(
# varref,
# node_typ,
# static_array_member,
# )
#
# _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
return AccessBytesIndex( def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
t_u8,
varref,
slice_expr,
)
if isinstance(node_typ, TypeTuple):
if not isinstance(slice_expr, ConstantUInt32):
_raise_static_error(node, 'Must subscript using a constant index')
idx = slice_expr.value
if len(node_typ.members) <= idx:
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
tuple_member = node_typ.members[idx]
if exp_type != tuple_member.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}')
if isinstance(varref, ModuleConstantReference):
raise NotImplementedError(f'{node} from module constant')
return AccessTupleMember(
varref,
tuple_member,
)
if isinstance(node_typ, TypeStaticArray):
if exp_type != node_typ.member_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
if not isinstance(slice_expr, ConstantInt32):
return AccessStaticArrayMember(
varref,
node_typ,
slice_expr,
)
idx = slice_expr.value
if len(node_typ.members) <= idx:
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
static_array_member = node_typ.members[idx]
return AccessStaticArrayMember(
varref,
node_typ,
static_array_member,
)
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant:
del module del module
_not_implemented(node.kind is None, 'Constant.kind') _not_implemented(node.kind is None, 'Constant.kind')
if isinstance(exp_type, TypeUInt8): if isinstance(node.value, (int, float, )):
if not isinstance(node.value, int): return ConstantPrimitive(node.value)
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 255: raise NotImplementedError(f'{node.value} as constant')
_raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
return ConstantUInt8(exp_type, node.value) def visit_type(self, module: Module, node: ast.expr) -> str:
if isinstance(exp_type, TypeUInt32):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 4294967295:
_raise_static_error(node, 'Integer value out of range')
return ConstantUInt32(exp_type, node.value)
if isinstance(exp_type, TypeUInt64):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 18446744073709551615:
_raise_static_error(node, 'Integer value out of range')
return ConstantUInt64(exp_type, node.value)
if isinstance(exp_type, TypeInt32):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < -2147483648 or node.value > 2147483647:
_raise_static_error(node, 'Integer value out of range')
return ConstantInt32(exp_type, node.value)
if isinstance(exp_type, TypeInt64):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < -9223372036854775808 or node.value > 9223372036854775807:
_raise_static_error(node, 'Integer value out of range')
return ConstantInt64(exp_type, node.value)
if isinstance(exp_type, TypeFloat32):
if not isinstance(node.value, (float, int, )):
_raise_static_error(node, 'Expected float value')
# FIXME: Range check
return ConstantFloat32(exp_type, node.value)
if isinstance(exp_type, TypeFloat64):
if not isinstance(node.value, (float, int, )):
_raise_static_error(node, 'Expected float value')
# FIXME: Range check
return ConstantFloat64(exp_type, node.value)
raise NotImplementedError(f'{node} as const for type {exp_type}')
def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
if node.value is None: if node.value is None:
return module.types['None'] return 'None'
_raise_static_error(node, f'Unrecognized type {node.value}') _raise_static_error(node, f'Unrecognized type {node.value}')
@ -767,8 +684,10 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.id in module.types: if node.id in BUILTIN_TYPES:
return module.types[node.id] return node.id
raise NotImplementedError('TODO: Broken after type system')
if node.id in module.structs: if node.id in module.structs:
return module.structs[node.id] return module.structs[node.id]
@ -787,50 +706,35 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.value.id in module.types: if node.value.id not in BUILTIN_TYPES: # FIXME: Tuple of tuples?
member_type = module.types[node.value.id]
else:
_raise_static_error(node, f'Unrecognized type {node.value.id}') _raise_static_error(node, f'Unrecognized type {node.value.id}')
type_static_array = TypeStaticArray(member_type) return f'{node.value.id}[{node.slice.value.value}]'
offset = 0
for idx in range(node.slice.value.value):
static_array_member = TypeStaticArrayMember(idx, offset)
type_static_array.members.append(static_array_member)
offset += member_type.alloc_size()
key = f'{node.value.id}[{node.slice.value.value}]'
if key not in module.types:
module.types[key] = type_static_array
return module.types[key]
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load): raise NotImplementedError('TODO: Broken after new type system')
_raise_static_error(node, 'Must be load context')
type_tuple = TypeTuple() # if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context')
offset = 0 #
# type_tuple = TypeTuple()
for idx, elt in enumerate(node.elts): #
tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) # offset = 0
#
type_tuple.members.append(tuple_member) # for idx, elt in enumerate(node.elts):
offset += tuple_member.type.alloc_size() # tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset)
#
key = type_tuple.render_internal_name() # type_tuple.members.append(tuple_member)
# offset += tuple_member.type.alloc_size()
if key not in module.types: #
module.types[key] = type_tuple # key = type_tuple.render_internal_name()
constructor = TupleConstructor(type_tuple) #
module.functions[constructor.name] = constructor # if key not in module.types:
# module.types[key] = type_tuple
return module.types[key] # constructor = TupleConstructor(type_tuple)
# module.functions[constructor.name] = constructor
#
# return module.types[key]
raise NotImplementedError(f'{node} as type') raise NotImplementedError(f'{node} as type')

View File

@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
g.i32.const(0) g.i32.const(0)
g.return_() g.return_()
del alloc_size # TODO del alloc_size # TODO: Actual implement using a previously freed block
g.unreachable() g.unreachable()
return i32('return') # To satisfy mypy return i32('return') # To satisfy mypy

192
phasm/typer.py Normal file
View File

@ -0,0 +1,192 @@
"""
Type checks and enriches the given ast
"""
from . import ourlang
from .exceptions import TypingError
from .typing import (
Context,
TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript,
TypeVar,
from_str,
)
def phasm_type(inp: ourlang.Module) -> None:
module(inp)
def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
if isinstance(inp, ourlang.ConstantPrimitive):
result = ctx.new_var()
if isinstance(inp.value, int):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
# Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
# Don't dictate anything about signedness - you can use a signed
# constant in an unsigned variable if the bits fit
result.add_constraint(TypeConstraintSigned(None))
result.add_location(str(inp.value))
inp.type_var = result
return result
if isinstance(inp.value, float):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
# We don't have fancy logic here to detect if the float constant
# fits in the given type. There a number of edge cases to consider,
# before implementing this.
# 1) It may fit anyhow
# e.g., if the user has 3.14 as a float constant, neither a
# f32 nor a f64 can really fit this value. But does that mean
# we should throw an error?
# If we'd implement it, we'd want to convert it to hex using
# inp.value.hex(), which would give us the mantissa and exponent.
# We can use those to determine what bit size the value should be in.
# If that doesn't work out, we'd need another way to calculate the
# difference between what was written and what actually gets stored
# in memory, and warn if the difference is beyond a treshold.
result.add_location(str(inp.value))
inp.type_var = result
return result
raise NotImplementedError(constant, inp, inp.value)
if isinstance(inp, ourlang.ConstantTuple):
result = ctx.new_var()
result.add_constraint(TypeConstraintSubscript(members=(
constant(ctx, x)
for x in inp.value
)))
result.add_location(str(inp.value))
inp.type_var = result
return result
raise NotImplementedError(constant, inp)
def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant):
return constant(ctx, inp)
if isinstance(inp, ourlang.VariableReference):
assert inp.variable.type_var is not None
inp.type_var = inp.variable.type_var
return inp.variable.type_var
if isinstance(inp, ourlang.UnaryOp):
# TODO: Simplified version
if inp.operator not in ('sqrt', ):
raise NotImplementedError(expression, inp, inp.operator)
right = expression(ctx, inp.right)
inp.type_var = right
return right
if isinstance(inp, ourlang.BinaryOp):
if inp.operator in ('+', '-', '*', '|', '&', '^'):
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
inp.type_var = left
return left
if inp.operator in ('<<', '>>', ):
inp.type_var = ctx.new_var()
inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, )))
inp.type_var.add_constraint(TypeConstraintSigned(False))
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
ctx.unify(inp.type_var, left)
return left
raise NotImplementedError(expression, inp, inp.operator)
if isinstance(inp, ourlang.FunctionCall):
assert inp.function.returns_type_var is not None
for param, expr in zip(inp.function.posonlyargs, inp.arguments):
assert param.type_var is not None
arg = expression(ctx, expr)
ctx.unify(param.type_var, arg)
return inp.function.returns_type_var
if isinstance(inp, ourlang.Subscript):
if not isinstance(inp.index, ourlang.ConstantPrimitive):
raise NotImplementedError(expression, inp, inp.index)
if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value)
expression(ctx, inp.varref)
assert inp.varref.type_var is not None
# TODO: I'd much rather resolve this using the narrow functions
tc_subs = inp.varref.type_var.get_constraint(TypeConstraintSubscript)
if tc_subs is None:
raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None
try:
# TODO: I'd much rather resolve this using the narrow functions
member = tc_subs.members[inp.index.value]
except IndexError:
raise TypingError(f'Type cannot be subscripted with index {inp.index.value}: {inp.varref.type_var}') from None
inp.type_var = member
return member
raise NotImplementedError(expression, inp)
def function(ctx: Context, inp: ourlang.Function) -> None:
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement')
typ = expression(ctx, inp.statements[0].value)
assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ)
def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None:
constant(ctx, inp.constant)
if inp.type_str is None:
inp.type_var = ctx.new_var()
else:
inp.type_var = from_str(ctx, inp.type_str)
assert inp.constant.type_var is not None
ctx.unify(inp.type_var, inp.constant.type_var)
def module(inp: ourlang.Module) -> None:
ctx = Context()
for func in inp.functions.values():
func.returns_type_var = from_str(ctx, func.returns_str, f'{func.name}.(returns)')
for param in func.posonlyargs:
param.type_var = from_str(ctx, param.type_str, f'{func.name}.{param.name}')
for cdef in inp.constant_defs.values():
module_constant_def(ctx, cdef)
for func in inp.functions.values():
function(ctx, func)

View File

@ -1,7 +1,13 @@
""" """
The phasm type system The phasm type system
""" """
from typing import Optional, List from typing import Callable, Dict, Iterable, Optional, List, Set, Type
from typing import TypeVar as MyPyTypeVar
import enum
import re
from .exceptions import TypingError
class TypeBase: class TypeBase:
""" """
@ -15,88 +21,6 @@ class TypeBase:
""" """
raise NotImplementedError(self, 'alloc_size') raise NotImplementedError(self, 'alloc_size')
class TypeNone(TypeBase):
"""
The None (or Void) type
"""
__slots__ = ()
class TypeBool(TypeBase):
"""
The boolean type
"""
__slots__ = ()
class TypeUInt8(TypeBase):
"""
The Integer type, unsigned and 8 bits wide
Note that under the hood we need to use i32 to represent
these values in expressions. So we need to add some operations
to make sure the math checks out.
So while this does save bytes in memory, it may not actually
speed up or improve your code.
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4 # Int32 under the hood
class TypeUInt32(TypeBase):
"""
The Integer type, unsigned and 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeUInt64(TypeBase):
"""
The Integer type, unsigned and 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeInt32(TypeBase):
"""
The Integer type, signed and 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeInt64(TypeBase):
"""
The Integer type, signed and 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeFloat32(TypeBase):
"""
The Float type, 32 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 4
class TypeFloat64(TypeBase):
"""
The Float type, 64 bits wide
"""
__slots__ = ()
def alloc_size(self) -> int:
return 8
class TypeBytes(TypeBase): class TypeBytes(TypeBase):
""" """
The bytes type The bytes type
@ -200,3 +124,487 @@ class TypeStruct(TypeBase):
x.type.alloc_size() x.type.alloc_size()
for x in self.members for x in self.members
) )
## NEW STUFF BELOW
# This error can also mean that the typer somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class TypingNarrowProtoError(TypingError):
"""
A proto error when trying to narrow two types
This gets turned into a TypingNarrowError by the unify method
"""
# FIXME: Use consistent naming for unify / narrow / entangle
class TypingNarrowError(TypingError):
"""
An error when trying to unify two Type Variables
"""
def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
super().__init__(
f'Cannot narrow types {l} and {r}: {msg}'
)
class TypeConstraintBase:
"""
Base class for classes implementing a contraint on a type
"""
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
"""
This contraint on a type defines its primitive shape
"""
__slots__ = ('primitive', )
class Primitive(enum.Enum):
"""
The primitive ID
"""
INT = 0
FLOAT = 1
STATIC_ARRAY = 10
primitive: Primitive
def __init__(self, primitive: Primitive) -> None:
self.primitive = primitive
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
if not isinstance(other, TypeConstraintPrimitive):
raise Exception('Invalid comparison')
if self.primitive != other.primitive:
raise TypingNarrowProtoError('Primitive does not match')
return TypeConstraintPrimitive(self.primitive)
def __repr__(self) -> str:
return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase):
"""
Contraint on whether a signed value can be used or not, or whether
a value can be used in a signed expression
"""
__slots__ = ('signed', )
signed: Optional[bool]
def __init__(self, signed: Optional[bool]) -> None:
self.signed = signed
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
if not isinstance(other, TypeConstraintSigned):
raise Exception('Invalid comparison')
if other.signed is None:
return TypeConstraintSigned(self.signed)
if self.signed is None:
return TypeConstraintSigned(other.signed)
if self.signed is not other.signed:
raise TypingNarrowProtoError('Signed does not match')
return TypeConstraintSigned(self.signed)
def __repr__(self) -> str:
return f'Signed={self.signed}'
class TypeConstraintBitWidth(TypeConstraintBase):
"""
Contraint on how many bits an expression has or can possibly have
"""
__slots__ = ('oneof', )
oneof: Set[int]
def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None:
# For now, support up to 64 bits values
self.oneof = set(oneof) if oneof is not None else set(range(1, 65))
if minb is not None:
self.oneof = {
x
for x in self.oneof
if minb <= x
}
if maxb is not None:
self.oneof = {
x
for x in self.oneof
if x <= maxb
}
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
if not isinstance(other, TypeConstraintBitWidth):
raise Exception('Invalid comparison')
new_oneof = self.oneof & other.oneof
if not new_oneof:
raise TypingNarrowProtoError('Memory width cannot be resolved')
return TypeConstraintBitWidth(oneof=new_oneof)
def __repr__(self) -> str:
result = 'BitWidth='
items = list(sorted(self.oneof))
if not items:
return result
while items:
itm = items.pop(0)
result += str(itm)
cnt = 0
while cnt < len(items) and items[cnt] == itm + cnt + 1:
cnt += 1
if cnt == 1:
result += ',' + str(items[0])
elif cnt > 1:
result += '..' + str(items[cnt - 1])
items = items[cnt:]
if items:
result += ','
return result
class TypeConstraintSubscript(TypeConstraintBase):
"""
Contraint on allowing a type to be subscripted
"""
__slots__ = ('members', )
members: List['TypeVar']
def __init__(self, *, members: Iterable['TypeVar']) -> None:
self.members = list(members)
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript':
if not isinstance(other, TypeConstraintSubscript):
raise Exception('Invalid comparison')
if len(self.members) != len(other.members):
raise TypingNarrowProtoError('Member count does not match')
newmembers = []
for smb, omb in zip(self.members, other.members):
nmb = ctx.new_var()
ctx.unify(nmb, smb)
ctx.unify(nmb, omb)
newmembers.append(nmb)
return TypeConstraintSubscript(members=newmembers)
def __repr__(self) -> str:
return 'Subscript=(' + ','.join(map(repr, self.members)) + ')'
TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase)
class TypeVar:
"""
A type variable
"""
# FIXME: Explain the type system
__slots__ = ('ctx', 'ctx_id', )
ctx: 'Context'
ctx_id: int
def __init__(self, ctx: 'Context', ctx_id: int) -> None:
self.ctx = ctx
self.ctx_id = ctx_id
def add_constraint(self, newconst: TypeConstraintBase) -> None:
csts = self.ctx.var_constraints[self.ctx_id]
if newconst.__class__ in csts:
csts[newconst.__class__] = csts[newconst.__class__].narrow(self.ctx, newconst)
else:
csts[newconst.__class__] = newconst
def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]:
csts = self.ctx.var_constraints[self.ctx_id]
res = csts.get(const_type, None)
assert res is None or isinstance(res, const_type) # type hint
return res
def add_location(self, ref: str) -> None:
self.ctx.var_locations[self.ctx_id].add(ref)
def __repr__(self) -> str:
return (
'TypeVar<'
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
+ '; locations: '
+ ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
+ '>'
)
class Context:
"""
The context for a collection of type variables
"""
def __init__(self) -> None:
# Variables are unified (or entangled, if you will)
# that means that each TypeVar within a context has an ID,
# and all TypeVars with the same ID are the same TypeVar,
# even if they are a different instance
self.next_ctx_id = 1
self.vars_by_id: Dict[int, List[TypeVar]] = {}
# Store the TypeVar properties as a lookup
# so we can update these when unifying
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
self.var_locations: Dict[int, Set[str]] = {}
def new_var(self) -> TypeVar:
ctx_id = self.next_ctx_id
self.next_ctx_id += 1
result = TypeVar(self, ctx_id)
self.vars_by_id[ctx_id] = [result]
self.var_constraints[ctx_id] = {}
self.var_locations[ctx_id] = set()
return result
def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None:
# FIXME: Write method doc, find out why pylint doesn't error
assert l is not None, ASSERTION_ERROR
assert r is not None, ASSERTION_ERROR
assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return
# Backup some values that we'll overwrite
l_ctx_id = l.ctx_id
r_ctx_id = r.ctx_id
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
# Create a new TypeVar, with the combined contraints
# and locations of the old ones
n = self.new_var()
try:
for const in self.var_constraints[l_ctx_id].values():
n.add_constraint(const)
for const in self.var_constraints[r_ctx_id].values():
n.add_constraint(const)
except TypingNarrowProtoError as exc:
raise TypingNarrowError(l, r, str(exc)) from None
self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id]
# ##
# And unify (or entangle) the old ones
# First update the IDs, so they all point to the new list
for type_var in l_r_var_list:
type_var.ctx_id = n.ctx_id
# Update our registry of TypeVars by ID, so we can find them
# on the next unify
self.vars_by_id[n.ctx_id].extend(l_r_var_list)
# Then delete the old values for the now gone variables
# Do this last, so exceptions thrown in the code above
# still have a valid context
del self.var_constraints[l_ctx_id]
del self.var_constraints[r_ctx_id]
del self.var_locations[l_ctx_id]
del self.var_locations[r_ctx_id]
def simplify(inp: TypeVar) -> Optional[str]:
"""
Simplifies a TypeVar into a string that wasm can work with
and users can recognize
Should round trip with from_str
"""
tc_prim = inp.get_constraint(TypeConstraintPrimitive)
tc_bits = inp.get_constraint(TypeConstraintBitWidth)
tc_sign = inp.get_constraint(TypeConstraintSigned)
if tc_prim is None:
return None
primitive = tc_prim.primitive
if primitive is TypeConstraintPrimitive.Primitive.INT:
if tc_bits is None or tc_sign is None:
return None
if tc_sign.signed is None or len(tc_bits.oneof) != 1:
return None
bitwidth = next(iter(tc_bits.oneof))
if bitwidth not in (8, 32, 64):
return None
base = 'i' if tc_sign.signed else 'u'
return f'{base}{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.FLOAT:
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
return None
if len(tc_bits.oneof) != 1:
return None
bitwidth = next(iter(tc_bits.oneof))
if bitwidth not in (32, 64):
return None
return f'f{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
tc_subs = inp.get_constraint(TypeConstraintSubscript)
assert tc_subs is not None
assert tc_subs.members
sab = simplify(tc_subs.members[0])
if sab is None:
return None
return f'{sab}[{len(tc_subs.members)}]'
return None
def make_u8(ctx: Context) -> TypeVar:
"""
Makes a u8 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u8')
return result
def make_u32(ctx: Context) -> TypeVar:
"""
Makes a u32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u32')
return result
def make_u64(ctx: Context) -> TypeVar:
"""
Makes a u64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u64')
return result
def make_i32(ctx: Context) -> TypeVar:
"""
Makes a i32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i32')
return result
def make_i64(ctx: Context) -> TypeVar:
"""
Makes a i64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i64')
return result
def make_f32(ctx: Context) -> TypeVar:
"""
Makes a f32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location('f32')
return result
def make_f64(ctx: Context) -> TypeVar:
"""
Makes a f64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location('f64')
return result
BUILTIN_TYPES: Dict[str, Callable[[Context], TypeVar]] = {
'u8': make_u8,
'u32': make_u32,
'u64': make_u64,
'i32': make_i32,
'i64': make_i64,
'f32': make_f32,
'f64': make_f64,
}
TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]')
def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar:
"""
Creates a new TypeVar from the string
Should round trip with simplify
The location is a reference to where you found the string
in the source code.
This could be conidered part of parsing. Though that would give trouble
with the context creation.
"""
if inp in BUILTIN_TYPES:
result = BUILTIN_TYPES[inp](ctx)
if location is not None:
result.add_location(location)
return result
match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
if match:
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
result.add_constraint(TypeConstraintSubscript(members=(
# Make copies so they don't get entangled
# with each other.
from_str(ctx, match[1], match[1])
for _ in range(int(match[2]))
)))
result.add_location(inp)
if location is not None:
result.add_location(location)
return result
raise NotImplementedError(from_str, inp)

View File

@ -1,7 +1,7 @@
""" """
Helper functions to quickly generate WASM code Helper functions to quickly generate WASM code
""" """
from typing import Any, Dict, List, Optional, Type from typing import List, Optional
import functools import functools

View File

@ -1,5 +1,5 @@
[MASTER] [MASTER]
disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223
max-line-length=180 max-line-length=180
@ -7,4 +7,4 @@ max-line-length=180
good-names=g good-names=g
[tests] [tests]
disable=C0116, disable=C0116,R0201

View File

@ -0,0 +1,16 @@
"""
Constants for use in the tests
"""
ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64']
COMPLETE_INT_TYPES = ['u32', 'u64', 'i32', 'i64']
ALL_FLOAT_TYPES = ['f32', 'f64']
COMPLETE_FLOAT_TYPES = ALL_FLOAT_TYPES
TYPE_MAP = {
**{x: int for x in ALL_INT_TYPES},
**{x: float for x in ALL_FLOAT_TYPES},
}
COMPLETE_PRIMITIVE_TYPES = COMPLETE_INT_TYPES + COMPLETE_FLOAT_TYPES

View File

@ -13,6 +13,7 @@ import wasmtime
from phasm.compiler import phasm_compile from phasm.compiler import phasm_compile
from phasm.parser import phasm_parse from phasm.parser import phasm_parse
from phasm.typer import phasm_type
from phasm import ourlang from phasm import ourlang
from phasm import wasm from phasm import wasm
@ -40,6 +41,7 @@ class RunnerBase:
Parses the Phasm code into an AST Parses the Phasm code into an AST
""" """
self.phasm_ast = phasm_parse(self.phasm_code) self.phasm_ast = phasm_parse(self.phasm_code)
phasm_type(self.phasm_ast)
def compile_ast(self) -> None: def compile_ast(self) -> None:
""" """

View File

View File

@ -0,0 +1,20 @@
import pytest
from phasm import typing as sut
class TestTypeConstraintBitWidth:
@pytest.mark.parametrize('oneof,exp', [
(set(), '', ),
({1}, '1', ),
({1,2}, '1,2', ),
({1,2,3}, '1..3', ),
({1,2,3,4}, '1..4', ),
({1,3}, '1,3', ),
({1,4}, '1,4', ),
({1,2,3,4,6,7,8,9}, '1..4,6..9', ),
])
def test_repr(self, oneof, exp):
mut_self = sut.TypeConstraintBitWidth(oneof=oneof)
assert ('BitWidth=' + exp) == repr(mut_self)

View File

@ -3,9 +3,9 @@ import struct
import pytest import pytest
from .helpers import Suite from ..helpers import Suite
@pytest.mark.integration_test @pytest.mark.slow_integration_test
def test_crc32(): def test_crc32():
# FIXME: Stub # FIXME: Stub
# crc = 0xFFFFFFFF # crc = 0xFFFFFFFF

View File

@ -1,6 +1,6 @@
import pytest import pytest
from .helpers import Suite from ..helpers import Suite
@pytest.mark.slow_integration_test @pytest.mark.slow_integration_test
def test_fib(): def test_fib():

View File

@ -1,70 +0,0 @@
import io
import pytest
from pywasm import binary
from pywasm import Runtime
from wasmer import wat2wasm
def run(code_wat):
code_wasm = wat2wasm(code_wat)
module = binary.Module.from_reader(io.BytesIO(code_wasm))
runtime = Runtime(module, {}, {})
out_put = runtime.exec('testEntry', [])
return (runtime, out_put)
@pytest.mark.parametrize('size,offset,exp_out_put', [
('32', 0, 0x3020100),
('32', 1, 0x4030201),
('64', 0, 0x706050403020100),
('64', 2, 0x908070605040302),
])
def test_i32_64_load(size, offset, exp_out_put):
code_wat = f"""
(module
(memory 1)
(data (memory 0) (i32.const 0) "\\00\\01\\02\\03\\04\\05\\06\\07\\08\\09\\10")
(func (export "testEntry") (result i{size})
i32.const {offset}
i{size}.load
return ))
"""
(_, out_put) = run(code_wat)
assert exp_out_put == out_put
def test_load_then_store():
code_wat = """
(module
(memory 1)
(data (memory 0) (i32.const 0) "\\04\\00\\00\\00")
(func (export "testEntry") (result i32) (local $my_memory_value i32)
;; Load i32 from address 0
i32.const 0
i32.load
;; Add 8 to the loaded value
i32.const 8
i32.add
local.set $my_memory_value
;; Store back to the memory
i32.const 0
local.get $my_memory_value
i32.store
;; Return something
i32.const 9
return ))
"""
(runtime, out_put) = run(code_wat)
assert 9 == out_put
assert (b'\x0c'+ b'\00' * 23) == runtime.store.mems[0].data[:24]

View File

View File

@ -2,8 +2,8 @@ import sys
import pytest import pytest
from .helpers import Suite, write_header from ..helpers import Suite, write_header
from .runners import RunnerPywasm from ..runners import RunnerPywasm
def setup_interpreter(phash_code: str) -> RunnerPywasm: def setup_interpreter(phash_code: str) -> RunnerPywasm:
runner = RunnerPywasm(phash_code) runner = RunnerPywasm(phash_code)

View File

@ -0,0 +1,53 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_bytes_address():
code_py = """
@exported
def testEntry(f: bytes) -> bytes:
return f
"""
result = Suite(code_py).run_code(b'This is a test')
# THIS DEPENDS ON THE ALLOCATOR
# A different allocator will return a different value
assert 20 == result.returned_value
@pytest.mark.integration_test
def test_bytes_length():
code_py = """
@exported
def testEntry(f: bytes) -> i32:
return len(f)
"""
result = Suite(code_py).run_code(b'This is another test')
assert 20 == result.returned_value
@pytest.mark.integration_test
def test_bytes_index():
code_py = """
@exported
def testEntry(f: bytes) -> u8:
return f[8]
"""
result = Suite(code_py).run_code(b'This is another test')
assert 0x61 == result.returned_value
@pytest.mark.integration_test
def test_bytes_index_out_of_bounds():
code_py = """
@exported
def testEntry(f: bytes) -> u8:
return f[50]
"""
result = Suite(code_py).run_code(b'Short', b'Long' * 100)
assert 0 == result.returned_value

View File

@ -0,0 +1,71 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
@pytest.mark.parametrize('inp', [9, 10, 11, 12])
def test_if_simple(inp):
code_py = """
@exported
def testEntry(a: i32) -> i32:
if a > 10:
return 15
return 3
"""
exp_result = 15 if inp > 10 else 3
suite = Suite(code_py)
result = suite.run_code(inp)
assert exp_result == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('Such a return is not how things should be')
def test_if_complex():
code_py = """
@exported
def testEntry(a: i32) -> i32:
if a > 10:
return 10
elif a > 0:
return a
else:
return 0
return -1 # Required due to function type
"""
suite = Suite(code_py)
assert 10 == suite.run_code(20).returned_value
assert 10 == suite.run_code(10).returned_value
assert 8 == suite.run_code(8).returned_value
assert 0 == suite.run_code(0).returned_value
assert 0 == suite.run_code(-1).returned_value
@pytest.mark.integration_test
def test_if_nested():
code_py = """
@exported
def testEntry(a: i32, b: i32) -> i32:
if a > 11:
if b > 11:
return 3
return 2
if b > 11:
return 1
return 0
"""
suite = Suite(code_py)
assert 3 == suite.run_code(20, 20).returned_value
assert 2 == suite.run_code(20, 10).returned_value
assert 1 == suite.run_code(10, 20).returned_value
assert 0 == suite.run_code(10, 10).returned_value

View File

@ -0,0 +1,27 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_imported():
code_py = """
@imported
def helper(mul: i32) -> i32:
pass
@exported
def testEntry() -> i32:
return helper(2)
"""
def helper(mul: int) -> int:
return 4238 * mul
result = Suite(code_py).run_code(
runtime='wasmer',
imports={
'helper': helper,
}
)
assert 8476 == result.returned_value

View File

@ -0,0 +1,376 @@
import pytest
from phasm.exceptions import TypingError
from ..helpers import Suite
from ..constants import ALL_INT_TYPES, ALL_FLOAT_TYPES, COMPLETE_INT_TYPES, TYPE_MAP
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_INT_TYPES)
def test_expr_constant_int(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 13
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_expr_constant_float(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 32.125
"""
result = Suite(code_py).run_code()
assert 32.125 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_expr_constant_entanglement():
code_py = """
@exported
def testEntry() -> u8:
return 1000
"""
with pytest.raises(TypingError, match='u8.*1000'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_INT_TYPES)
def test_module_constant_int(type_):
code_py = f"""
CONSTANT: {type_} = 13
@exported
def testEntry() -> {type_}:
return CONSTANT
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_module_constant_float(type_):
code_py = f"""
CONSTANT: {type_} = 32.125
@exported
def testEntry() -> {type_}:
return CONSTANT
"""
result = Suite(code_py).run_code()
assert 32.125 == result.returned_value
@pytest.mark.integration_test
def test_module_constant_entanglement():
code_py = """
CONSTANT: u8 = 1000
@exported
def testEntry() -> u32:
return 14
"""
with pytest.raises(TypingError, match='u8.*1000'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation
def test_logical_left_shift(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 << 3
"""
result = Suite(code_py).run_code()
assert 80 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u32', 'u64'])
def test_logical_right_shift_left_bit_zero(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 >> 3
"""
result = Suite(code_py).run_code()
assert 1 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_logical_right_shift_left_bit_one():
code_py = """
@exported
def testEntry() -> u32:
return 4294967295 >> 16
"""
result = Suite(code_py).run_code()
assert 0xFFFF == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_or(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 | 3
"""
result = Suite(code_py).run_code()
assert 11 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_xor(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 ^ 3
"""
result = Suite(code_py).run_code()
assert 9 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_and(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 & 3
"""
result = Suite(code_py).run_code()
assert 2 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_addition_int(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 + 3
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_addition_float(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 32.0 + 0.125
"""
result = Suite(code_py).run_code()
assert 32.125 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_subtraction_int(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 - 3
"""
result = Suite(code_py).run_code()
assert 7 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_subtraction_float(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 100.0 - 67.875
"""
result = Suite(code_py).run_code()
assert 32.125 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
# TODO: Multiplication
# TODO: Division
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['f32', 'f64'])
def test_builtins_sqrt(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return sqrt(25.0)
"""
result = Suite(code_py).run_code()
assert 5 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_MAP.keys())
def test_function_argument(type_):
code_py = f"""
@exported
def testEntry(a: {type_}) -> {type_}:
return a
"""
result = Suite(code_py).run_code(125)
assert 125 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.skip('TODO')
def test_explicit_positive_number():
code_py = """
@exported
def testEntry() -> i32:
return +523
"""
result = Suite(code_py).run_code()
assert 523 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('TODO')
def test_explicit_negative_number():
code_py = """
@exported
def testEntry() -> i32:
return -19
"""
result = Suite(code_py).run_code()
assert -19 == result.returned_value
@pytest.mark.integration_test
def test_call_no_args():
code_py = """
def helper() -> i32:
return 19
@exported
def testEntry() -> i32:
return helper()
"""
result = Suite(code_py).run_code()
assert 19 == result.returned_value
@pytest.mark.integration_test
def test_call_pre_defined():
code_py = """
def helper(left: i32, right: i32) -> i32:
return left + right
@exported
def testEntry() -> i32:
return helper(10, 3)
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
def test_call_post_defined():
code_py = """
@exported
def testEntry() -> i32:
return helper(10, 3)
def helper(left: i32, right: i32) -> i32:
return left - right
"""
result = Suite(code_py).run_code()
assert 7 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_call_with_expression_int(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return helper(10 + 20, 3 + 5)
def helper(left: {type_}, right: {type_}) -> {type_}:
return left - right
"""
result = Suite(code_py).run_code()
assert 22 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_call_with_expression_float(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return helper(10.078125 + 90.046875, 63.0 + 5.0)
def helper(left: {type_}, right: {type_}) -> {type_}:
return left - right
"""
result = Suite(code_py).run_code()
assert 32.125 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_call_invalid_type():
code_py = """
def helper() -> i64:
return 19
@exported
def testEntry() -> i32:
return helper()
"""
with pytest.raises(TypingError, match=r'i32.*i64'):
Suite(code_py).run_code()

View File

@ -0,0 +1,173 @@
import pytest
from phasm.exceptions import StaticError, TypingError
from ..constants import (
ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
)
from ..helpers import Suite
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_INT_TYPES)
def test_module_constant(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return CONSTANT[0]
"""
result = Suite(code_py).run_code()
assert 24 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_static_array_indexed(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT, 0, 1, 2)
def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
return array[i0] + array[i1] + array[i2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_function_call_int(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_function_call_float(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24.0, 57.5, 80.75, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 162.25 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_module_constant_type_mismatch_bitwidth():
code_py = """
CONSTANT: u8[3] = (24, 57, 280, )
"""
with pytest.raises(TypingError, match='u8.*280'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_return_as_int():
code_py = """
CONSTANT: u8[3] = (24, 57, 80, )
def testEntry() -> u32:
return CONSTANT
"""
with pytest.raises(TypingError, match=r'u32.*u8\[3\]'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_module_constant_type_mismatch_not_subscriptable():
code_py = """
CONSTANT: u8 = 24
@exported
def testEntry() -> u8:
return CONSTANT[0]
"""
with pytest.raises(TypingError, match='Type cannot be subscripted:'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_module_constant_type_mismatch_index_out_of_range():
code_py = """
CONSTANT: u8[3] = (24, 57, 80, )
@exported
def testEntry() -> u8:
return CONSTANT[3]
"""
with pytest.raises(TypingError, match='Type cannot be subscripted with index 3:'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_static_array_constant_too_few_values():
code_py = """
CONSTANT: u8[3] = (24, 57, )
"""
with pytest.raises(TypingError, match='Member count does not match'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_static_array_constant_too_many_values():
code_py = """
CONSTANT: u8[3] = (24, 57, 1, 1, )
"""
with pytest.raises(TypingError, match='Member count does not match'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_static_array_constant_type_mismatch():
code_py = """
CONSTANT: u8[3] = (24, 4000, 1, )
"""
with pytest.raises(TypingError, match='u8.*4000'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
def test_static_array_index_out_of_bounds():
code_py = """
CONSTANT0: u32[3] = (24, 57, 80, )
CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, )
@exported
def testEntry() -> u32:
return CONSTANT0[16]
"""
result = Suite(code_py).run_code()
assert 0 == result.returned_value

View File

@ -1,18 +1,65 @@
import pytest import pytest
from phasm.parser import phasm_parse from ..helpers import Suite
from phasm.exceptions import StaticError
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @pytest.mark.parametrize('type_', ('i32', 'f64', ))
def test_type_mismatch_function_argument(type_): def test_struct_0(type_):
code_py = f""" code_py = f"""
def helper(a: {type_}) -> (i32, i32, ): class CheckedValue:
return a value: {type_}
@exported
def testEntry() -> {type_}:
return helper(CheckedValue(23))
def helper(cv: CheckedValue) -> {type_}:
return cv.value
""" """
with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), a is actually {type_}'): result = Suite(code_py).run_code()
phasm_parse(code_py)
assert 23 == result.returned_value
@pytest.mark.integration_test
def test_struct_1():
code_py = """
class Rectangle:
height: i32
width: i32
border: i32
@exported
def testEntry() -> i32:
return helper(Rectangle(100, 150, 2))
def helper(shape: Rectangle) -> i32:
return shape.height + shape.width + shape.border
"""
result = Suite(code_py).run_code()
assert 252 == result.returned_value
@pytest.mark.integration_test
def test_struct_2():
code_py = """
class Rectangle:
height: i32
width: i32
border: i32
@exported
def testEntry() -> i32:
return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3))
def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border
"""
result = Suite(code_py).run_code()
assert 545 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
@ -39,21 +86,6 @@ def testEntry(arg: ({type_}, )) -> (i32, i32, ):
with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), arg\\[0\\] is actually {type_}'): with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), arg\\[0\\] is actually {type_}'):
phasm_parse(code_py) phasm_parse(code_py)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
def test_type_mismatch_function_result(type_):
code_py = f"""
def helper() -> {type_}:
return 1
@exported
def testEntry() -> (i32, i32, ):
return helper()
"""
with pytest.raises(StaticError, match=f'Static error on line 7: Expected \\(i32, i32, \\), helper actually returns {type_}'):
phasm_parse(code_py)
@pytest.mark.integration_test @pytest.mark.integration_test
def test_tuple_constant_too_few_values(): def test_tuple_constant_too_few_values():
code_py = """ code_py = """
@ -80,30 +112,3 @@ CONSTANT: (u32, u8, u8, ) = (24, 4000, 1, )
with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'):
phasm_parse(code_py) phasm_parse(code_py)
@pytest.mark.integration_test
def test_static_array_constant_too_few_values():
code_py = """
CONSTANT: u8[3] = (24, 57, )
"""
with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'):
phasm_parse(code_py)
@pytest.mark.integration_test
def test_static_array_constant_too_many_values():
code_py = """
CONSTANT: u8[3] = (24, 57, 1, 1, )
"""
with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'):
phasm_parse(code_py)
@pytest.mark.integration_test
def test_static_array_constant_type_mismatch():
code_py = """
CONSTANT: u8[3] = (24, 4000, 1, )
"""
with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'):
phasm_parse(code_py)

View File

@ -1,20 +1,7 @@
import pytest import pytest
from .helpers import Suite from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
from ..helpers import Suite
@pytest.mark.integration_test
def test_i32():
code_py = """
CONSTANT: i32 = 13
@exported
def testEntry() -> i32:
return CONSTANT * 5
"""
result = Suite(code_py).run_code()
assert 65 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) @pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ])
@ -52,36 +39,46 @@ def helper(vector: (u8, u8, u32, u32, u64, u64, )) -> u32:
assert 3333 == result.returned_value assert 3333 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) @pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_static_array_1(type_): def test_tuple_simple_constructor(type_):
code_py = f""" code_py = f"""
CONSTANT: {type_}[1] = (65, )
@exported @exported
def testEntry() -> {type_}: def testEntry() -> {type_}:
return helper(CONSTANT) return helper((24, 57, 80, ))
def helper(vector: {type_}[1]) -> {type_}: def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}:
return vector[0] return vector[0] + vector[1] + vector[2]
""" """
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
assert 65 == result.returned_value assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
def test_static_array_6(): def test_tuple_float():
code_py = """ code_py = """
CONSTANT: u32[6] = (11, 22, 3333, 4444, 555555, 666666, )
@exported @exported
def testEntry() -> u32: def testEntry() -> f32:
return helper(CONSTANT) return helper((1.0, 2.0, 3.0, ))
def helper(vector: u32[6]) -> u32: def helper(v: (f32, f32, f32, )) -> f32:
return vector[2] return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])
""" """
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
assert 3333 == result.returned_value assert 3.74 < result.returned_value < 3.75
@pytest.mark.integration_test
@pytest.mark.skip('SIMD support is but a dream')
def test_tuple_i32x4():
code_py = """
@exported
def testEntry() -> i32x4:
return (51, 153, 204, 0, )
"""
result = Suite(code_py).run_code()
assert (1, 2, 3, 0) == result.returned_value

View File

@ -1,31 +0,0 @@
import pytest
from .helpers import Suite
@pytest.mark.integration_test
def test_bytes_index_out_of_bounds():
code_py = """
@exported
def testEntry(f: bytes) -> u8:
return f[50]
"""
result = Suite(code_py).run_code(b'Short', b'Long' * 100)
assert 0 == result.returned_value
@pytest.mark.integration_test
def test_static_array_index_out_of_bounds():
code_py = """
CONSTANT0: u32[3] = (24, 57, 80, )
CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, )
@exported
def testEntry() -> u32:
return CONSTANT0[16]
"""
result = Suite(code_py).run_code()
assert 0 == result.returned_value

View File

@ -1,571 +0,0 @@
import pytest
from .helpers import Suite
TYPE_MAP = {
'u8': int,
'u32': int,
'u64': int,
'i32': int,
'i64': int,
'f32': float,
'f64': float,
}
COMPLETE_SIMPLE_TYPES = [
'u32', 'u64',
'i32', 'i64',
'f32', 'f64',
]
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_MAP.keys())
def test_return(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 13
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES)
def test_addition(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 + 3
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES)
def test_subtraction(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 - 3
"""
result = Suite(code_py).run_code()
assert 7 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation
def test_logical_left_shift(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 << 3
"""
result = Suite(code_py).run_code()
assert 80 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_logical_right_shift(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 >> 3
"""
result = Suite(code_py).run_code()
assert 1 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_or(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 | 3
"""
result = Suite(code_py).run_code()
assert 11 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_xor(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 ^ 3
"""
result = Suite(code_py).run_code()
assert 9 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_and(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 & 3
"""
result = Suite(code_py).run_code()
assert 2 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['f32', 'f64'])
def test_buildins_sqrt(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return sqrt(25)
"""
result = Suite(code_py).run_code()
assert 5 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_MAP.keys())
def test_arg(type_):
code_py = f"""
@exported
def testEntry(a: {type_}) -> {type_}:
return a
"""
result = Suite(code_py).run_code(125)
assert 125 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.skip('Do we want it to work like this?')
def test_i32_to_i64():
code_py = """
@exported
def testEntry(a: i32) -> i64:
return a
"""
result = Suite(code_py).run_code(125)
assert 125 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('Do we want it to work like this?')
def test_i32_plus_i64():
code_py = """
@exported
def testEntry(a: i32, b: i64) -> i64:
return a + b
"""
result = Suite(code_py).run_code(125, 100)
assert 225 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('Do we want it to work like this?')
def test_f32_to_f64():
code_py = """
@exported
def testEntry(a: f32) -> f64:
return a
"""
result = Suite(code_py).run_code(125.5)
assert 125.5 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('Do we want it to work like this?')
def test_f32_plus_f64():
code_py = """
@exported
def testEntry(a: f32, b: f64) -> f64:
return a + b
"""
result = Suite(code_py).run_code(125.5, 100.25)
assert 225.75 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('TODO')
def test_uadd():
code_py = """
@exported
def testEntry() -> i32:
return +523
"""
result = Suite(code_py).run_code()
assert 523 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('TODO')
def test_usub():
code_py = """
@exported
def testEntry() -> i32:
return -19
"""
result = Suite(code_py).run_code()
assert -19 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('inp', [9, 10, 11, 12])
def test_if_simple(inp):
code_py = """
@exported
def testEntry(a: i32) -> i32:
if a > 10:
return 15
return 3
"""
exp_result = 15 if inp > 10 else 3
suite = Suite(code_py)
result = suite.run_code(inp)
assert exp_result == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('Such a return is not how things should be')
def test_if_complex():
code_py = """
@exported
def testEntry(a: i32) -> i32:
if a > 10:
return 10
elif a > 0:
return a
else:
return 0
return -1 # Required due to function type
"""
suite = Suite(code_py)
assert 10 == suite.run_code(20).returned_value
assert 10 == suite.run_code(10).returned_value
assert 8 == suite.run_code(8).returned_value
assert 0 == suite.run_code(0).returned_value
assert 0 == suite.run_code(-1).returned_value
@pytest.mark.integration_test
def test_if_nested():
code_py = """
@exported
def testEntry(a: i32, b: i32) -> i32:
if a > 11:
if b > 11:
return 3
return 2
if b > 11:
return 1
return 0
"""
suite = Suite(code_py)
assert 3 == suite.run_code(20, 20).returned_value
assert 2 == suite.run_code(20, 10).returned_value
assert 1 == suite.run_code(10, 20).returned_value
assert 0 == suite.run_code(10, 10).returned_value
@pytest.mark.integration_test
def test_call_pre_defined():
code_py = """
def helper(left: i32, right: i32) -> i32:
return left + right
@exported
def testEntry() -> i32:
return helper(10, 3)
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
def test_call_post_defined():
code_py = """
@exported
def testEntry() -> i32:
return helper(10, 3)
def helper(left: i32, right: i32) -> i32:
return left - right
"""
result = Suite(code_py).run_code()
assert 7 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES)
def test_call_with_expression(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return helper(10 + 20, 3 + 5)
def helper(left: {type_}, right: {type_}) -> {type_}:
return left - right
"""
result = Suite(code_py).run_code()
assert 22 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.skip('Not yet implemented')
def test_assign():
code_py = """
@exported
def testEntry() -> i32:
a: i32 = 8947
return a
"""
result = Suite(code_py).run_code()
assert 8947 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_MAP.keys())
def test_struct_0(type_):
code_py = f"""
class CheckedValue:
value: {type_}
@exported
def testEntry() -> {type_}:
return helper(CheckedValue(23))
def helper(cv: CheckedValue) -> {type_}:
return cv.value
"""
result = Suite(code_py).run_code()
assert 23 == result.returned_value
@pytest.mark.integration_test
def test_struct_1():
code_py = """
class Rectangle:
height: i32
width: i32
border: i32
@exported
def testEntry() -> i32:
return helper(Rectangle(100, 150, 2))
def helper(shape: Rectangle) -> i32:
return shape.height + shape.width + shape.border
"""
result = Suite(code_py).run_code()
assert 252 == result.returned_value
@pytest.mark.integration_test
def test_struct_2():
code_py = """
class Rectangle:
height: i32
width: i32
border: i32
@exported
def testEntry() -> i32:
return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3))
def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border
"""
result = Suite(code_py).run_code()
assert 545 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES)
def test_tuple_simple_constructor(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return helper((24, 57, 80, ))
def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}:
return vector[0] + vector[1] + vector[2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_tuple_float():
code_py = """
@exported
def testEntry() -> f32:
return helper((1.0, 2.0, 3.0, ))
def helper(v: (f32, f32, f32, )) -> f32:
return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])
"""
result = Suite(code_py).run_code()
assert 3.74 < result.returned_value < 3.75
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES)
def test_static_array_module_constant(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES)
def test_static_array_indexed(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT, 0, 1, 2)
def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
return array[i0] + array[i1] + array[i2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_bytes_address():
code_py = """
@exported
def testEntry(f: bytes) -> bytes:
return f
"""
result = Suite(code_py).run_code(b'This is a test')
# THIS DEPENDS ON THE ALLOCATOR
# A different allocator will return a different value
assert 20 == result.returned_value
@pytest.mark.integration_test
def test_bytes_length():
code_py = """
@exported
def testEntry(f: bytes) -> i32:
return len(f)
"""
result = Suite(code_py).run_code(b'This is another test')
assert 20 == result.returned_value
@pytest.mark.integration_test
def test_bytes_index():
code_py = """
@exported
def testEntry(f: bytes) -> u8:
return f[8]
"""
result = Suite(code_py).run_code(b'This is another test')
assert 0x61 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('SIMD support is but a dream')
def test_tuple_i32x4():
code_py = """
@exported
def testEntry() -> i32x4:
return (51, 153, 204, 0, )
"""
result = Suite(code_py).run_code()
assert (1, 2, 3, 0) == result.returned_value
@pytest.mark.integration_test
def test_imported():
code_py = """
@imported
def helper(mul: i32) -> i32:
pass
@exported
def testEntry() -> i32:
return helper(2)
"""
def helper(mul: int) -> int:
return 4238 * mul
result = Suite(code_py).run_code(
runtime='wasmer',
imports={
'helper': helper,
}
)
assert 8476 == result.returned_value

View File

@ -2,8 +2,8 @@ import sys
import pytest import pytest
from .helpers import write_header from ..helpers import write_header
from .runners import RunnerPywasm3 as Runner from ..runners import RunnerPywasm3 as Runner
def setup_interpreter(phash_code: str) -> Runner: def setup_interpreter(phash_code: str) -> Runner:
runner = Runner(phash_code) runner = Runner(phash_code)