From 205897101ff43a5c1e2a6285c4d93e735c2b01a0 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 14:42:40 +0200 Subject: [PATCH] Adds a typing system to Phasm --- Makefile | 5 +- README.md | 4 +- TODO.md | 3 +- examples/buffer.py | 2 +- phasm/__main__.py | 2 + phasm/codestyle.py | 144 ++--- phasm/compiler.py | 527 ++++++++++------ phasm/ourlang.py | 343 ++++------- phasm/parser.py | 520 +++++----------- phasm/stdlib/alloc.py | 2 +- phasm/type3/__init__.py | 0 phasm/type3/constraints.py | 549 +++++++++++++++++ phasm/type3/constraintsgenerator.py | 219 +++++++ phasm/type3/entry.py | 154 +++++ phasm/type3/types.py | 322 ++++++++++ phasm/typing.py | 202 ------- phasm/wasmeasy.py | 69 --- pylintrc | 4 +- requirements.txt | 8 +- tests/integration/constants.py | 16 + tests/integration/helpers.py | 4 +- tests/integration/runners.py | 4 +- tests/integration/test_constants.py | 87 --- tests/integration/test_examples/__init__.py | 0 .../integration/test_examples/test_buffer.py | 19 + .../test_crc32.py} | 4 +- tests/integration/test_examples/test_fib.py | 12 + tests/integration/test_fib.py | 30 - tests/integration/test_helper.py | 70 --- tests/integration/test_lang/__init__.py | 0 .../{ => test_lang}/test_builtins.py | 4 +- tests/integration/test_lang/test_bytes.py | 84 +++ tests/integration/test_lang/test_if.py | 71 +++ tests/integration/test_lang/test_interface.py | 79 +++ .../integration/test_lang/test_primitives.py | 486 +++++++++++++++ .../test_lang/test_static_array.py | 195 ++++++ tests/integration/test_lang/test_struct.py | 154 +++++ tests/integration/test_lang/test_tuple.py | 188 ++++++ tests/integration/test_runtime_checks.py | 31 - tests/integration/test_simple.py | 571 ------------------ tests/integration/test_static_checking.py | 109 ---- tests/integration/test_stdlib/__init__.py | 0 .../test_alloc.py} | 4 +- 43 files changed, 3249 insertions(+), 2052 deletions(-) create mode 100644 phasm/type3/__init__.py create mode 100644 phasm/type3/constraints.py create mode 100644 phasm/type3/constraintsgenerator.py create mode 100644 phasm/type3/entry.py create mode 100644 phasm/type3/types.py delete mode 100644 phasm/typing.py delete mode 100644 phasm/wasmeasy.py create mode 100644 tests/integration/constants.py delete mode 100644 tests/integration/test_constants.py create mode 100644 tests/integration/test_examples/__init__.py create mode 100644 tests/integration/test_examples/test_buffer.py rename tests/integration/{test_examples.py => test_examples/test_crc32.py} (93%) create mode 100644 tests/integration/test_examples/test_fib.py delete mode 100644 tests/integration/test_fib.py delete mode 100644 tests/integration/test_helper.py create mode 100644 tests/integration/test_lang/__init__.py rename tests/integration/{ => test_lang}/test_builtins.py (95%) create mode 100644 tests/integration/test_lang/test_bytes.py create mode 100644 tests/integration/test_lang/test_if.py create mode 100644 tests/integration/test_lang/test_interface.py create mode 100644 tests/integration/test_lang/test_primitives.py create mode 100644 tests/integration/test_lang/test_static_array.py create mode 100644 tests/integration/test_lang/test_struct.py create mode 100644 tests/integration/test_lang/test_tuple.py delete mode 100644 tests/integration/test_runtime_checks.py delete mode 100644 tests/integration/test_simple.py delete mode 100644 tests/integration/test_static_checking.py create mode 100644 tests/integration/test_stdlib/__init__.py rename tests/integration/{test_stdlib_alloc.py => test_stdlib/test_alloc.py} (95%) diff --git a/Makefile b/Makefile index ad186c2..ac1bf09 100644 --- a/Makefile +++ b/Makefile @@ -34,11 +34,14 @@ typecheck: venv/.done venv/bin/mypy --strict phasm tests/integration/runners.py venv/.done: requirements.txt - python3.8 -m venv venv + python3.10 -m venv venv venv/bin/python3 -m pip install wheel pip --upgrade venv/bin/python3 -m pip install -r $^ touch $@ +clean-examples: + rm -f examples/*.wat examples/*.wasm examples/*.wat.html examples/*.py.html + .SECONDARY: # Keep intermediate files .PHONY: examples diff --git a/README.md b/README.md index c9cb056..a5af8d8 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ This is a hobby project for now. Use at your own risk. How to run ---------- -You should only need make and python3. Currently, we're working with python3.8, +You should only need make and python3. Currently, we're working with python3.10, since we're using the python ast parser, it might not work on other versions. To run the examples: @@ -32,7 +32,7 @@ make lint typecheck To compile a Phasm file: ```sh -python3.8 -m phasm source.py output.wat +python3.10 -m phasm source.py output.wat ``` Additional required tools diff --git a/TODO.md b/TODO.md index 0da4619..212fbd0 100644 --- a/TODO.md +++ b/TODO.md @@ -1,7 +1,8 @@ # TODO +- Rename constant to literal + - Implement a trace() builtin for debugging -- Implement a proper type matching / checking system - Check if we can use DataView in the Javascript examples, e.g. with setUint32 - Storing u8 in memory still claims 32 bits (since that's what you need in local variables). However, using load8_u / loadu_s we can optimize this. - Implement a FizzBuzz example diff --git a/examples/buffer.py b/examples/buffer.py index 6f251ff..8fe6570 100644 --- a/examples/buffer.py +++ b/examples/buffer.py @@ -3,5 +3,5 @@ def index(inp: bytes, idx: u32) -> u8: return inp[idx] @exported -def length(inp: bytes) -> i32: +def length(inp: bytes) -> u32: return len(inp) diff --git a/phasm/__main__.py b/phasm/__main__.py index bc94cdb..97615e0 100644 --- a/phasm/__main__.py +++ b/phasm/__main__.py @@ -5,6 +5,7 @@ Functions for using this module from CLI import sys from .parser import phasm_parse +from .type3.entry import phasm_type3 from .compiler import phasm_compile def main(source: str, sink: str) -> int: @@ -16,6 +17,7 @@ def main(source: str, sink: str) -> int: code_py = fil.read() our_module = phasm_parse(code_py) + phasm_type3(our_module, verbose=False) wasm_module = phasm_compile(our_module) code_wat = wasm_module.to_wat() diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 4b38e32..c984937 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -6,7 +6,8 @@ It's intented to be a "any color, as long as it's black" kind of renderer from typing import Generator from . import ourlang -from . import typing +from .type3 import types as type3types +from .type3.types import TYPE3_ASSERTION_ERROR, Type3, Type3OrPlaceholder def phasm_render(inp: ourlang.Module) -> str: """ @@ -16,63 +17,39 @@ def phasm_render(inp: ourlang.Module) -> str: Statements = Generator[str, None, None] -def type_(inp: typing.TypeBase) -> str: +def type3(inp: Type3OrPlaceholder) -> str: """ - Render: Type (name) + Render: type's name """ - if isinstance(inp, typing.TypeNone): + assert isinstance(inp, Type3), TYPE3_ASSERTION_ERROR + + if inp is type3types.none: return 'None' - if isinstance(inp, typing.TypeBool): - return 'bool' + if isinstance(inp, type3types.AppliedType3): + if inp.base == type3types.tuple: + return '(' + ', '.join( + type3(x) + for x in inp.args + if isinstance(x, Type3) # Skip ints, not allowed here anyhow + ) + ', )' - if isinstance(inp, typing.TypeUInt8): - return 'u8' + if inp.base == type3types.static_array: + assert 2 == len(inp.args) + assert isinstance(inp.args[0], Type3), TYPE3_ASSERTION_ERROR + assert isinstance(inp.args[1], type3types.IntType3), TYPE3_ASSERTION_ERROR - if isinstance(inp, typing.TypeUInt32): - return 'u32' + return inp.args[0].name + '[' + inp.args[1].name + ']' - if isinstance(inp, typing.TypeUInt64): - return 'u64' + return inp.name - if isinstance(inp, typing.TypeInt32): - return 'i32' - - if isinstance(inp, typing.TypeInt64): - return 'i64' - - if isinstance(inp, typing.TypeFloat32): - return 'f32' - - if isinstance(inp, typing.TypeFloat64): - return 'f64' - - if isinstance(inp, typing.TypeBytes): - return 'bytes' - - if isinstance(inp, typing.TypeTuple): - mems = ', '.join( - type_(x.type) - for x in inp.members - ) - - return f'({mems}, )' - - if isinstance(inp, typing.TypeStaticArray): - return f'{type_(inp.member_type)}[{len(inp.members)}]' - - if isinstance(inp, typing.TypeStruct): - return inp.name - - raise NotImplementedError(type_, inp) - -def struct_definition(inp: typing.TypeStruct) -> str: +def struct_definition(inp: ourlang.StructDefinition) -> str: """ Render: TypeStruct's definition """ - result = f'class {inp.name}:\n' - for mem in inp.members: - result += f' {mem.name}: {type_(mem.type)}\n' + result = f'class {inp.struct_type3.name}:\n' + for mem, typ in inp.struct_type3.members.items(): + result += f' {mem}: {type3(typ)}\n' return result @@ -80,40 +57,44 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str: """ Render: Module Constant's definition """ - return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' + return f'{inp.name}: {type3(inp.type3)} = {expression(inp.constant)}\n' def expression(inp: ourlang.Expression) -> str: """ Render: A Phasm expression """ - if isinstance(inp, ( - ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, - ourlang.ConstantInt32, ourlang.ConstantInt64, - )): - return str(inp.value) - - if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )): - # These might not round trip if the original constant + if isinstance(inp, ourlang.ConstantPrimitive): + # Floats might not round trip if the original constant # could not fit in the given float type return str(inp.value) - if isinstance(inp, (ourlang.ConstantTuple, ourlang.ConstantStaticArray, )): + if isinstance(inp, ourlang.ConstantTuple): return '(' + ', '.join( expression(x) for x in inp.value ) + ', )' + if isinstance(inp, ourlang.ConstantStruct): + return inp.struct_name + '(' + ', '.join( + expression(x) + for x in inp.value + ) + ')' + if isinstance(inp, ourlang.VariableReference): - return str(inp.name) + return str(inp.variable.name) if isinstance(inp, ourlang.UnaryOp): if ( - inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS - or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): + inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS + or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS): return f'{inp.operator}({expression(inp.right)})' if inp.operator == 'cast': - return f'{type_(inp.type)}({expression(inp.right)})' + mtyp = type3(inp.type3) + if mtyp is None: + raise NotImplementedError(f'Casting to type {inp.type_var}') + + return f'{mtyp}({expression(inp.right)})' return f'{inp.operator}{expression(inp.right)}' @@ -127,32 +108,31 @@ def expression(inp: ourlang.Expression) -> str: ) if isinstance(inp.function, ourlang.StructConstructor): - return f'{inp.function.struct.name}({args})' - - if isinstance(inp.function, ourlang.TupleConstructor): - return f'({args}, )' + return f'{inp.function.struct_type3.name}({args})' return f'{inp.function.name}({args})' - if isinstance(inp, ourlang.AccessBytesIndex): - return f'{expression(inp.varref)}[{expression(inp.index)}]' + if isinstance(inp, ourlang.TupleInstantiation): + args = ', '.join( + expression(arg) + for arg in inp.elements + ) + + return f'({args}, )' + + if isinstance(inp, ourlang.Subscript): + varref = expression(inp.varref) + index = expression(inp.index) + + return f'{varref}[{index}]' if isinstance(inp, ourlang.AccessStructMember): - return f'{expression(inp.varref)}.{inp.member.name}' - - if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): - if isinstance(inp.member, ourlang.Expression): - return f'{expression(inp.varref)}[{expression(inp.member)}]' - - return f'{expression(inp.varref)}[{inp.member.idx}]' + return f'{expression(inp.varref)}.{inp.member}' if isinstance(inp, ourlang.Fold): fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - if isinstance(inp, ourlang.ModuleConstantReference): - return inp.definition.name - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: @@ -193,11 +173,11 @@ def function(inp: ourlang.Function) -> str: result += '@imported\n' args = ', '.join( - f'{x}: {type_(y)}' - for x, y in inp.posonlyargs + f'{p.name}: {type3(p.type3)}' + for p in inp.posonlyargs ) - result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' + result += f'def {inp.name}({args}) -> {type3(inp.returns_type3)}:\n' if inp.imported: result += ' pass\n' @@ -215,7 +195,7 @@ def module(inp: ourlang.Module) -> str: """ result = '' - for struct in inp.structs.values(): + for struct in inp.struct_definitions.values(): if result: result += '\n' result += struct_definition(struct) @@ -227,7 +207,7 @@ def module(inp: ourlang.Module) -> str: for func in inp.functions.values(): if func.lineno < 0: - # Buildin (-2) or auto generated (-1) + # Builtin (-2) or auto generated (-1) continue if result: diff --git a/phasm/compiler.py b/phasm/compiler.py index 4ba5d5a..c2bbad1 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -1,11 +1,13 @@ """ This module contains the code to convert parsed Ourlang into WebAssembly code """ +from typing import List, Union + import struct from . import codestyle from . import ourlang -from . import typing +from .type3 import types as type3types from . import wasm from .stdlib import alloc as stdlib_alloc @@ -13,17 +15,14 @@ from .stdlib import types as stdlib_types from .wasmgenerator import Generator as WasmGenerator LOAD_STORE_TYPE_MAP = { - typing.TypeUInt8: 'i32', - typing.TypeUInt32: 'i32', - typing.TypeUInt64: 'i64', - typing.TypeInt32: 'i32', - typing.TypeInt64: 'i64', - typing.TypeFloat32: 'f32', - typing.TypeFloat64: 'f64', + 'u8': 'i32', # Have to use an u32, since there is no native u8 type + 'i32': 'i32', + 'i64': 'i64', + 'u32': 'i32', + 'u64': 'i64', + 'f32': 'f32', + 'f64': 'f64', } -""" -When generating code, we sometimes need to load or store simple values -""" def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ @@ -32,42 +31,60 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ return module(inp) -def type_(inp: typing.TypeBase) -> wasm.WasmType: +def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType: """ Compile: type + + Types are used for example in WebAssembly function parameters + and return types. """ - if isinstance(inp, typing.TypeNone): + assert isinstance(inp, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + + if inp == type3types.none: return wasm.WasmTypeNone() - if isinstance(inp, typing.TypeUInt8): + if inp == type3types.u8: # WebAssembly has only support for 32 and 64 bits # So we need to store more memory per byte return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeUInt32): + if inp == type3types.u32: return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeUInt64): + if inp == type3types.u64: return wasm.WasmTypeInt64() - if isinstance(inp, typing.TypeInt32): + if inp == type3types.i32: return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeInt64): + if inp == type3types.i64: return wasm.WasmTypeInt64() - if isinstance(inp, typing.TypeFloat32): + if inp == type3types.f32: return wasm.WasmTypeFloat32() - if isinstance(inp, typing.TypeFloat64): + if inp == type3types.f64: return wasm.WasmTypeFloat64() - if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): - # Structs and tuples are passed as pointer + if inp == type3types.bytes: + # bytes are passed as pointer # And pointers are i32 return wasm.WasmTypeInt32() - raise NotImplementedError(type_, inp) + if isinstance(inp, type3types.StructType3): + # Structs are passed as pointer, which are i32 + return wasm.WasmTypeInt32() + + if isinstance(inp, type3types.AppliedType3): + if inp.base == type3types.static_array: + # Static Arrays are passed as pointer, which are i32 + return wasm.WasmTypeInt32() + + if inp.base == type3types.tuple: + # Tuples are passed as pointer, which are i32 + return wasm.WasmTypeInt32() + + raise NotImplementedError(type3, inp) # Operators that work for i32, i64, f32, f64 OPERATOR_MAP = { @@ -81,8 +98,6 @@ U8_OPERATOR_MAP = { # Under the hood, this is an i32 # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining # bytes stay zero after this operation - # Since it's unsigned an unsigned value, Logical or Arithmetic shift right - # are the same operation '>>': 'shr_u', '^': 'xor', '|': 'or', @@ -99,6 +114,7 @@ U32_OPERATOR_MAP = { '^': 'xor', '|': 'or', '&': 'and', + '/': 'div_u' # Division by zero is a trap and the program will panic } U64_OPERATOR_MAP = { @@ -111,6 +127,7 @@ U64_OPERATOR_MAP = { '^': 'xor', '|': 'or', '&': 'and', + '/': 'div_u' # Division by zero is a trap and the program will panic } I32_OPERATOR_MAP = { @@ -118,6 +135,7 @@ I32_OPERATOR_MAP = { '>': 'gt_s', '<=': 'le_s', '>=': 'ge_s', + '/': 'div_s' # Division by zero is a trap and the program will panic } I64_OPERATOR_MAP = { @@ -125,115 +143,226 @@ I64_OPERATOR_MAP = { '>': 'gt_s', '<=': 'le_s', '>=': 'ge_s', + '/': 'div_s' # Division by zero is a trap and the program will panic } +F32_OPERATOR_MAP = { + '/': 'div' # Division by zero is a trap and the program will panic +} + +F64_OPERATOR_MAP = { + '/': 'div' # Division by zero is a trap and the program will panic +} + +def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: + """ + Compile: Instantiation (allocation) of a tuple + """ + assert isinstance(inp.type3, type3types.AppliedType3) + assert inp.type3.base is type3types.tuple + assert len(inp.elements) == len(inp.type3.args) + + comment_elements = '' + for element in inp.elements: + assert isinstance(element.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + comment_elements += f'{element.type3.name}, ' + + tmp_var = wgn.temp_var_i32('tuple_adr') + wgn.add_statement('nop', comment=f'{tmp_var.name} := ({comment_elements})') + + # Allocated the required amounts of bytes in memory + wgn.i32.const(_calculate_alloc_size(inp.type3)) + wgn.call(stdlib_alloc.__alloc__) + wgn.local.set(tmp_var) + + # Store each element individually + offset = 0 + for element, exp_type3 in zip(inp.elements, inp.type3.args): + if isinstance(exp_type3, type3types.PlaceholderForType): + assert exp_type3.resolve_as is not None + exp_type3 = exp_type3.resolve_as + + assert element.type3 == exp_type3 + + assert isinstance(exp_type3, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') + mtyp = LOAD_STORE_TYPE_MAP[exp_type3.name] + + wgn.local.get(tmp_var) + expression(wgn, element) + wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset)) + + offset += _calculate_alloc_size(exp_type3) + + # Return the allocated address + wgn.local.get(tmp_var) + def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: """ Compile: Any expression """ - if isinstance(inp, ourlang.ConstantUInt8): - wgn.i32.const(inp.value) - return + if isinstance(inp, ourlang.ConstantPrimitive): + assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - if isinstance(inp, ourlang.ConstantUInt32): - wgn.i32.const(inp.value) - return + if inp.type3 == type3types.u8: + # No native u8 type - treat as i32, with caution + assert isinstance(inp.value, int) + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantUInt64): - wgn.i64.const(inp.value) - return + if inp.type3 in (type3types.i32, type3types.u32, ): + assert isinstance(inp.value, int) + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt32): - wgn.i32.const(inp.value) - return + if inp.type3 in (type3types.i64, type3types.u64, ): + assert isinstance(inp.value, int) + wgn.i64.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt64): - wgn.i64.const(inp.value) - return + if inp.type3 == type3types.f32: + assert isinstance(inp.value, float) + wgn.f32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantFloat32): - wgn.f32.const(inp.value) - return + if inp.type3 == type3types.f64: + assert isinstance(inp.value, float) + wgn.f64.const(inp.value) + return - if isinstance(inp, ourlang.ConstantFloat64): - wgn.f64.const(inp.value) - return + raise NotImplementedError(f'Constants with type {inp.type3}') if isinstance(inp, ourlang.VariableReference): - wgn.add_statement('local.get', '${}'.format(inp.name)) - return + if isinstance(inp.variable, ourlang.FunctionParam): + wgn.add_statement('local.get', '${}'.format(inp.variable.name)) + return + + if isinstance(inp.variable, ourlang.ModuleConstantDef): + assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + + if isinstance(inp.type3, type3types.StructType3): + assert inp.variable.data_block is not None, 'Structs must be memory stored' + assert inp.variable.data_block.address is not None, 'Value not allocated' + wgn.i32.const(inp.variable.data_block.address) + return + + if isinstance(inp.type3, type3types.AppliedType3): + if inp.type3.base == type3types.static_array: + assert inp.variable.data_block is not None, 'Static arrays must be memory stored' + assert inp.variable.data_block.address is not None, 'Value not allocated' + wgn.i32.const(inp.variable.data_block.address) + return + + if inp.type3.base == type3types.tuple: + assert inp.variable.data_block is not None, 'Tuples must be memory stored' + assert inp.variable.data_block.address is not None, 'Value not allocated' + wgn.i32.const(inp.variable.data_block.address) + return + + raise NotImplementedError(expression, inp.variable, inp.type3.base) + + assert inp.variable.data_block is None, 'Primitives are not memory stored' + + expression(wgn, inp.variable.constant) + return + + raise NotImplementedError(expression, inp.variable) if isinstance(inp, ourlang.BinaryOp): expression(wgn, inp.left) expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeUInt8): + assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + # FIXME: Re-implement build-in operators + # Maybe operator_annotation is the way to go + # Maybe the older stuff below that is the way to go + + operator_annotation = f'({inp.operator}) :: {inp.left.type3:s} -> {inp.right.type3:s} -> {inp.type3:s}' + if operator_annotation == '(>) :: i32 -> i32 -> bool': + wgn.add_statement('i32.gt_s') + return + + if operator_annotation == '(<) :: u64 -> u64 -> bool': + wgn.add_statement('i64.lt_u') + return + if operator_annotation == '(==) :: u64 -> u64 -> bool': + wgn.add_statement('i64.eq') + return + + if inp.type3 == type3types.u8: if operator := U8_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt32): + if inp.type3 == type3types.u32: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := U32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt64): + if inp.type3 == type3types.u64: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := U64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeInt32): + if inp.type3 == type3types.i32: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := I32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeInt64): + if inp.type3 == type3types.i64: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeFloat32): + if inp.type3 == type3types.f32: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return - if isinstance(inp.type, typing.TypeFloat64): + if operator := F32_OPERATOR_MAP.get(inp.operator, None): + wgn.add_statement(f'f32.{operator}') + return + if inp.type3 == type3types.f64: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f64.{operator}') return + if operator := F64_OPERATOR_MAP.get(inp.operator, None): + wgn.add_statement(f'f64.{operator}') + return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.operator, inp.left.type3, inp.right.type3, inp.type3) if isinstance(inp, ourlang.UnaryOp): expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeFloat32): - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + + if inp.type3 == type3types.f32: + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return - if isinstance(inp.type, typing.TypeFloat64): - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if inp.type3 == type3types.f64: + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return - if isinstance(inp.type, typing.TypeInt32): + if inp.type3 == type3types.u32: if inp.operator == 'len': - if isinstance(inp.right.type, typing.TypeBytes): + if inp.right.type3 == type3types.bytes: wgn.i32.load() return if inp.operator == 'cast': - if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): + if inp.type3 == type3types.u32 and inp.right.type3 == type3types.u8: # Nothing to do, you can use an u8 value as a u32 no problem return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type3, inp.operator) if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: @@ -242,104 +371,103 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.add_statement('call', '${}'.format(inp.function.name)) return - if isinstance(inp, ourlang.AccessBytesIndex): - if not isinstance(inp.type, typing.TypeUInt8): - raise NotImplementedError(inp, inp.type) - - expression(wgn, inp.varref) - expression(wgn, inp.index) - wgn.call(stdlib_types.__subscript_bytes__) + if isinstance(inp, ourlang.TupleInstantiation): + tuple_instantiation(wgn, inp) return - if isinstance(inp, ourlang.AccessStructMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.member) + if isinstance(inp, ourlang.Subscript): + assert isinstance(inp.varref.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return - - if isinstance(inp, ourlang.AccessTupleMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.member) - - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return - - if isinstance(inp, ourlang.AccessStaticArrayMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of static arrays - raise NotImplementedError(expression, inp, inp.member) - - if isinstance(inp.member, typing.TypeStaticArrayMember): + if inp.varref.type3 is type3types.bytes: expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + expression(wgn, inp.index) + wgn.call(stdlib_types.__subscript_bytes__) return + if isinstance(inp.varref.type3, type3types.AppliedType3): + if inp.varref.type3.base == type3types.static_array: + assert 2 == len(inp.varref.type3.args) + el_type = inp.varref.type3.args[0] + assert isinstance(el_type, type3types.Type3) + el_len = inp.varref.type3.args[1] + assert isinstance(el_len, type3types.IntType3) + + # OPTIMIZE: If index is a constant, we can use offset instead of multiply + # and we don't need to do the out of bounds check + + expression(wgn, inp.varref) + + tmp_var = wgn.temp_var_i32('index') + expression(wgn, inp.index) + wgn.local.tee(tmp_var) + + # Out of bounds check based on el_len.value + wgn.i32.const(el_len.value) + wgn.i32.ge_u() + with wgn.if_(): + wgn.unreachable(comment='Out of bounds') + + wgn.local.get(tmp_var) + wgn.i32.const(_calculate_alloc_size(el_type)) + wgn.i32.mul() + wgn.i32.add() + + assert isinstance(el_type, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') + mtyp = LOAD_STORE_TYPE_MAP[el_type.name] + + wgn.add_statement(f'{mtyp}.load') + return + + if inp.varref.type3.base == type3types.tuple: + assert isinstance(inp.index, ourlang.ConstantPrimitive) + assert isinstance(inp.index.value, int) + + offset = 0 + for el_type in inp.varref.type3.args[0:inp.index.value]: + assert isinstance(el_type, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + offset += _calculate_alloc_size(el_type) + + # This doubles as the out of bounds check + el_type = inp.varref.type3.args[inp.index.value] + assert isinstance(el_type, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + + expression(wgn, inp.varref) + + assert isinstance(el_type, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') + mtyp = LOAD_STORE_TYPE_MAP[el_type.name] + + wgn.add_statement(f'{mtyp}.load', f'offset={offset}') + return + + raise NotImplementedError(expression, inp, inp.varref.type3) + + if isinstance(inp, ourlang.AccessStructMember): + assert isinstance(inp.struct_type3.members[inp.member], type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') + mtyp = LOAD_STORE_TYPE_MAP[inp.struct_type3.members[inp.member].name] + expression(wgn, inp.varref) - expression(wgn, inp.member) - wgn.i32.const(inp.static_array.member_type.alloc_size()) - wgn.i32.mul() - wgn.i32.add() - wgn.add_statement(f'{mtyp}.load') + wgn.add_statement(f'{mtyp}.load', 'offset=' + str(_calculate_member_offset( + inp.struct_type3, inp.member + ))) return if isinstance(inp, ourlang.Fold): expression_fold(wgn, inp) return - if isinstance(inp, ourlang.ModuleConstantReference): - if isinstance(inp.type, typing.TypeTuple): - assert isinstance(inp.definition.constant, ourlang.ConstantTuple) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return - - if isinstance(inp.type, typing.TypeStaticArray): - assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return - - assert inp.definition.data_block is None, 'Primitives are not memory stored' - - mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.type) - - expression(wgn, inp.definition.constant) - return - raise NotImplementedError(expression, inp) def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ - mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.base) + assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - if inp.iter.type.__class__.__name__ != 'TypeBytes': - raise NotImplementedError(expression, inp, inp.iter.type) + if inp.iter.type3 is not type3types.bytes: + raise NotImplementedError(expression_fold, inp, inp.iter.type3) wgn.add_statement('nop', comment='acu :: u8') - acu_var = wgn.temp_var_u8(f'fold_{codestyle.type_(inp.type)}_acu') + acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu') wgn.add_statement('nop', comment='adr :: bytes*') adr_var = wgn.temp_var_i32('fold_i32_adr') wgn.add_statement('nop', comment='len :: i32') @@ -359,7 +487,7 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: wgn.local.set(len_var) wgn.add_statement('nop', comment='i = 0') - idx_var = wgn.temp_var_i32(f'fold_{codestyle.type_(inp.type)}_idx') + idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx') wgn.i32.const(0) wgn.local.set(idx_var) @@ -450,7 +578,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param: """ Compile: function argument """ - return (inp[0], type_(inp[1]), ) + return (inp.name, type3(inp.type3), ) def import_(inp: ourlang.Function) -> wasm.Import: """ @@ -466,7 +594,7 @@ def import_(inp: ourlang.Function) -> wasm.Import: function_argument(x) for x in inp.posonlyargs ], - type_(inp.returns) + type3(inp.returns_type3) ) def function(inp: ourlang.Function) -> wasm.Function: @@ -477,9 +605,7 @@ def function(inp: ourlang.Function) -> wasm.Function: wgn = WasmGenerator() - if isinstance(inp, ourlang.TupleConstructor): - _generate_tuple_constructor(wgn, inp) - elif isinstance(inp, ourlang.StructConstructor): + if isinstance(inp, ourlang.StructConstructor): _generate_struct_constructor(wgn, inp) else: for stat in inp.statements: @@ -496,7 +622,7 @@ def function(inp: ourlang.Function) -> wasm.Function: (k, v.wasm_type(), ) for k, v in wgn.locals.items() ], - type_(inp.returns), + type3(inp.returns_type3), wgn.statements ) @@ -555,38 +681,47 @@ def module_data(inp: ourlang.ModuleData) -> bytes: for block in inp.blocks: block.address = unalloc_ptr + 4 # 4 bytes for allocator header - data_list = [] + data_list: List[bytes] = [] for constant in block.data: - if isinstance(constant, ourlang.ConstantUInt8): + assert isinstance(constant.type3, type3types.Type3), (id(constant), type3types.TYPE3_ASSERTION_ERROR) + + if constant.type3 == type3types.u8: + assert isinstance(constant.value, int) data_list.append(module_data_u8(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt32): + if constant.type3 == type3types.u32: + assert isinstance(constant.value, int) data_list.append(module_data_u32(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt64): + if constant.type3 == type3types.u64: + assert isinstance(constant.value, int) data_list.append(module_data_u64(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt32): + if constant.type3 == type3types.i32: + assert isinstance(constant.value, int) data_list.append(module_data_i32(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt64): + if constant.type3 == type3types.i64: + assert isinstance(constant.value, int) data_list.append(module_data_i64(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat32): + if constant.type3 == type3types.f32: + assert isinstance(constant.value, float) data_list.append(module_data_f32(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat64): + if constant.type3 == type3types.f64: + assert isinstance(constant.value, float) data_list.append(module_data_f64(constant.value)) continue - raise NotImplementedError(constant) + raise NotImplementedError(constant, constant.type3, constant.value) block_data = b''.join(data_list) @@ -636,48 +771,70 @@ def module(inp: ourlang.Module) -> wasm.Module: return result -def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: - tmp_var = wgn.temp_var_i32('tuple_adr') - - # Allocated the required amounts of bytes in memory - wgn.i32.const(inp.tuple.alloc_size()) - wgn.call(stdlib_alloc.__alloc__) - wgn.local.set(tmp_var) - - # Store each member individually - for member in inp.tuple.members: - mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, member) - - wgn.local.get(tmp_var) - wgn.add_statement('local.get', f'$arg{member.idx}') - wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) - - # Return the allocated address - wgn.local.get(tmp_var) - def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: tmp_var = wgn.temp_var_i32('struct_adr') # Allocated the required amounts of bytes in memory - wgn.i32.const(inp.struct.alloc_size()) + wgn.i32.const(_calculate_alloc_size(inp.struct_type3)) wgn.call(stdlib_alloc.__alloc__) wgn.local.set(tmp_var) # Store each member individually - for member in inp.struct.members: - mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) + for memname, mtyp3 in inp.struct_type3.members.items(): + mtyp = LOAD_STORE_TYPE_MAP.get(mtyp3.name) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples - raise NotImplementedError(expression, inp, member) + raise NotImplementedError(expression, inp, mtyp3) wgn.local.get(tmp_var) - wgn.add_statement('local.get', f'${member.name}') - wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) + wgn.add_statement('local.get', f'${memname}') + wgn.add_statement(f'{mtyp}.store', 'offset=' + str(_calculate_member_offset( + inp.struct_type3, memname + ))) # Return the allocated address wgn.local.get(tmp_var) + +def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) -> int: + if typ == type3types.u8: + return 4 # FIXME: We allocate 4 bytes for every u8 since you load them into an i32 + + if typ in (type3types.u32, type3types.i32, type3types.f32, ): + return 4 + + if typ in (type3types.u64, type3types.i64, type3types.f64, ): + return 8 + + if isinstance(typ, type3types.StructType3): + return sum( + _calculate_alloc_size(x) + for x in typ.members.values() + ) + + if isinstance(typ, type3types.AppliedType3): + if typ.base is type3types.tuple: + size = 0 + for arg in typ.args: + assert not isinstance(arg, type3types.IntType3) + + if isinstance(arg, type3types.PlaceholderForType): + assert not arg.resolve_as is None + arg = arg.resolve_as + + size += _calculate_alloc_size(arg) + + return size + + raise NotImplementedError(_calculate_alloc_size, typ) + +def _calculate_member_offset(struct_type3: type3types.StructType3, member: str) -> int: + result = 0 + + for mem, memtyp in struct_type3.members.items(): + if member == mem: + return result + + result += _calculate_alloc_size(memtyp) + + raise Exception(f'{member} not in {struct_type3}') diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 5f19d2e..01430b9 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,128 +1,51 @@ """ Contains the syntax tree for ourlang """ -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Optional, Union import enum from typing_extensions import Final -WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) -WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) +WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) +WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', ) -from .typing import ( - TypeBase, - TypeNone, - TypeBool, - TypeUInt8, TypeUInt32, TypeUInt64, - TypeInt32, TypeInt64, - TypeFloat32, TypeFloat64, - TypeBytes, - TypeTuple, TypeTupleMember, - TypeStaticArray, TypeStaticArrayMember, - TypeStruct, TypeStructMember, -) +from .type3 import types as type3types +from .type3.types import Type3, Type3OrPlaceholder, PlaceholderForType, StructType3 class Expression: """ An expression within a statement """ - __slots__ = ('type', ) + __slots__ = ('type3', ) - type: TypeBase + type3: Type3OrPlaceholder - def __init__(self, type_: TypeBase) -> None: - self.type = type_ + def __init__(self) -> None: + self.type3 = PlaceholderForType([self]) class Constant(Expression): """ An constant value expression within a statement + + # FIXME: Rename to literal """ __slots__ = () -class ConstantUInt8(Constant): +class ConstantPrimitive(Constant): """ - An UInt8 constant value expression within a statement + An primitive constant value expression within a statement """ __slots__ = ('value', ) - value: int + value: Union[int, float] - def __init__(self, type_: TypeUInt8, value: int) -> None: - super().__init__(type_) + def __init__(self, value: Union[int, float]) -> None: + super().__init__() self.value = value -class ConstantUInt32(Constant): - """ - An UInt32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt64(Constant): - """ - An UInt64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt32(Constant): - """ - An Int32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt64(Constant): - """ - An Int64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat32(Constant): - """ - An Float32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat32, value: float) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat64(Constant): - """ - An Float64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat64, value: float) -> None: - super().__init__(type_) - self.value = value + def __repr__(self) -> str: + return f'ConstantPrimitive({repr(self.value)})' class ConstantTuple(Constant): """ @@ -130,35 +53,43 @@ class ConstantTuple(Constant): """ __slots__ = ('value', ) - value: List[Constant] + value: List[ConstantPrimitive] - def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: - super().__init__(type_) + def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? + super().__init__() self.value = value -class ConstantStaticArray(Constant): - """ - A StaticArray constant value expression within a statement - """ - __slots__ = ('value', ) + def __repr__(self) -> str: + return f'ConstantTuple({repr(self.value)})' - value: List[Constant] +class ConstantStruct(Constant): + """ + A Struct constant value expression within a statement + """ + __slots__ = ('struct_name', 'value', ) - def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None: - super().__init__(type_) + struct_name: str + value: List[ConstantPrimitive] + + def __init__(self, struct_name: str, value: List[ConstantPrimitive]) -> None: # FIXME: Struct of structs? + super().__init__() + self.struct_name = struct_name self.value = value + def __repr__(self) -> str: + return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)})' + class VariableReference(Expression): """ An variable reference expression within a statement """ - __slots__ = ('name', ) + __slots__ = ('variable', ) - name: str + variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local - def __init__(self, type_: TypeBase, name: str) -> None: - super().__init__(type_) - self.name = name + def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None: + super().__init__() + self.variable = variable class UnaryOp(Expression): """ @@ -169,8 +100,8 @@ class UnaryOp(Expression): operator: str right: Expression - def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, right: Expression) -> None: + super().__init__() self.operator = operator self.right = right @@ -185,13 +116,16 @@ class BinaryOp(Expression): left: Expression right: Expression - def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, left: Expression, right: Expression) -> None: + super().__init__() self.operator = operator self.left = left self.right = right + def __repr__(self) -> str: + return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})' + class FunctionCall(Expression): """ A function call expression within a statement @@ -202,22 +136,36 @@ class FunctionCall(Expression): arguments: List[Expression] def __init__(self, function: 'Function') -> None: - super().__init__(function.returns) + super().__init__() self.function = function self.arguments = [] -class AccessBytesIndex(Expression): +class TupleInstantiation(Expression): """ - Access a bytes index for reading + Instantiation a tuple + """ + __slots__ = ('elements', ) + + elements: List[Expression] + + def __init__(self, elements: List[Expression]) -> None: + super().__init__() + + self.elements = elements + +class Subscript(Expression): + """ + A subscript, for example to refer to a static array or tuple + by index """ __slots__ = ('varref', 'index', ) varref: VariableReference index: Expression - def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None: - super().__init__(type_) + def __init__(self, varref: VariableReference, index: Expression) -> None: + super().__init__() self.varref = varref self.index = index @@ -226,47 +174,17 @@ class AccessStructMember(Expression): """ Access a struct member for reading of writing """ - __slots__ = ('varref', 'member', ) + __slots__ = ('varref', 'struct_type3', 'member', ) varref: VariableReference - member: TypeStructMember + struct_type3: StructType3 + member: str - def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: - super().__init__(member.type) + def __init__(self, varref: VariableReference, struct_type3: StructType3, member: str) -> None: + super().__init__() self.varref = varref - self.member = member - -class AccessTupleMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeTupleMember - - def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: - super().__init__(member.type) - - self.varref = varref - self.member = member - -class AccessStaticArrayMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'static_array', 'member', ) - - varref: Union['ModuleConstantReference', VariableReference] - static_array: TypeStaticArray - member: Union[Expression, TypeStaticArrayMember] - - def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: - super().__init__(static_array.member_type) - - self.varref = varref - self.static_array = static_array + self.struct_type3 = struct_type3 self.member = member class Fold(Expression): @@ -287,31 +205,18 @@ class Fold(Expression): def __init__( self, - type_: TypeBase, dir_: Direction, func: 'Function', base: Expression, iter_: Expression, ) -> None: - super().__init__(type_) + super().__init__() self.dir = dir_ self.func = func self.base = base self.iter = iter_ -class ModuleConstantReference(Expression): - """ - An reference to a module constant expression within a statement - """ - __slots__ = ('definition', ) - - definition: 'ModuleConstantDef' - - def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None: - super().__init__(type_) - self.definition = definition - class Statement: """ A statement within a function @@ -348,20 +253,31 @@ class StatementIf(Statement): self.statements = [] self.else_statements = [] -FunctionParam = Tuple[str, TypeBase] +class FunctionParam: + """ + A parameter for a Function + """ + __slots__ = ('name', 'type3', ) + + name: str + type3: Type3OrPlaceholder + + def __init__(self, name: str, type3: Optional[Type3]) -> None: + self.name = name + self.type3 = PlaceholderForType([self]) if type3 is None else type3 class Function: """ A function processes input and produces output """ - __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) + __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns_type3', 'posonlyargs', ) name: str lineno: int exported: bool imported: bool statements: List[Statement] - returns: TypeBase + returns_type3: Type3 posonlyargs: List[FunctionParam] def __init__(self, name: str, lineno: int) -> None: @@ -370,9 +286,22 @@ class Function: self.exported = False self.imported = False self.statements = [] - self.returns = TypeNone() + self.returns_type3 = type3types.none # FIXME: This could be a placeholder self.posonlyargs = [] +class StructDefinition: + """ + The definition for a struct + """ + __slots__ = ('struct_type3', 'lineno', ) + + struct_type3: StructType3 + lineno: int + + def __init__(self, struct_type3: StructType3, lineno: int) -> None: + self.struct_type3 = struct_type3 + self.lineno = lineno + class StructConstructor(Function): """ The constructor method for a struct @@ -380,56 +309,36 @@ class StructConstructor(Function): A function will generated to instantiate a struct. The arguments will be the defaults """ - __slots__ = ('struct', ) + __slots__ = ('struct_type3', ) - struct: TypeStruct + struct_type3: StructType3 - def __init__(self, struct: TypeStruct) -> None: - super().__init__(f'@{struct.name}@__init___@', -1) + def __init__(self, struct_type3: StructType3) -> None: + super().__init__(f'@{struct_type3.name}@__init___@', -1) - self.returns = struct + self.returns_type3 = struct_type3 - for mem in struct.members: - self.posonlyargs.append((mem.name, mem.type, )) + for mem, typ in struct_type3.members.items(): + self.posonlyargs.append(FunctionParam(mem, typ, )) - self.struct = struct - -class TupleConstructor(Function): - """ - The constructor method for a tuple - """ - __slots__ = ('tuple', ) - - tuple: TypeTuple - - def __init__(self, tuple_: TypeTuple) -> None: - name = tuple_.render_internal_name() - - super().__init__(f'@{name}@__init___@', -1) - - self.returns = tuple_ - - for mem in tuple_.members: - self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) - - self.tuple = tuple_ + self.struct_type3 = struct_type3 class ModuleConstantDef: """ A constant definition within a module """ - __slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', ) + __slots__ = ('name', 'lineno', 'type3', 'constant', 'data_block', ) name: str lineno: int - type: TypeBase + type3: Type3 constant: Constant data_block: Optional['ModuleDataBlock'] - def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: + def __init__(self, name: str, lineno: int, type3: Type3, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: self.name = name self.lineno = lineno - self.type = type_ + self.type3 = type3 self.constant = constant self.data_block = data_block @@ -439,10 +348,10 @@ class ModuleDataBlock: """ __slots__ = ('data', 'address', ) - data: List[Constant] + data: List[ConstantPrimitive] address: Optional[int] - def __init__(self, data: List[Constant]) -> None: + def __init__(self, data: List[ConstantPrimitive]) -> None: self.data = data self.address = None @@ -461,27 +370,15 @@ class Module: """ A module is a file and consists of functions """ - __slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',) + __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions',) data: ModuleData - types: Dict[str, TypeBase] - structs: Dict[str, TypeStruct] + struct_definitions: Dict[str, StructDefinition] constant_defs: Dict[str, ModuleConstantDef] functions: Dict[str, Function] def __init__(self) -> None: - self.types = { - 'None': TypeNone(), - 'u8': TypeUInt8(), - 'u32': TypeUInt32(), - 'u64': TypeUInt64(), - 'i32': TypeInt32(), - 'i64': TypeInt64(), - 'f32': TypeFloat32(), - 'f64': TypeFloat64(), - 'bytes': TypeBytes(), - } self.data = ModuleData() - self.structs = {} + self.struct_definitions = {} self.constant_defs = {} self.functions = {} diff --git a/phasm/parser.py b/phasm/parser.py index d95bfce..5ee9d56 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -5,49 +5,30 @@ from typing import Any, Dict, NoReturn, Union import ast -from .typing import ( - TypeBase, - TypeUInt8, - TypeUInt32, - TypeUInt64, - TypeInt32, - TypeInt64, - TypeFloat32, - TypeFloat64, - TypeBytes, - TypeStruct, - TypeStructMember, - TypeTuple, - TypeTupleMember, - TypeStaticArray, - TypeStaticArrayMember, -) +from .type3 import types as type3types -from . import codestyle from .exceptions import StaticError from .ourlang import ( - WEBASSEMBLY_BUILDIN_FLOAT_OPS, + WEBASSEMBLY_BUILTIN_FLOAT_OPS, Module, ModuleDataBlock, Function, Expression, - AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, - Constant, - ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, - ConstantUInt8, ConstantUInt32, ConstantUInt64, - ConstantTuple, ConstantStaticArray, + ConstantPrimitive, ConstantTuple, ConstantStruct, + TupleInstantiation, - FunctionCall, - StructConstructor, TupleConstructor, + FunctionCall, AccessStructMember, Subscript, + StructDefinition, StructConstructor, UnaryOp, VariableReference, - Fold, ModuleConstantReference, + Fold, Statement, StatementIf, StatementPass, StatementReturn, + FunctionParam, ModuleConstantDef, ) @@ -60,7 +41,7 @@ def phasm_parse(source: str) -> Module: our_visitor = OurVisitor() return our_visitor.visit_Module(res) -OurLocals = Dict[str, TypeBase] +OurLocals = Dict[str, Union[FunctionParam]] # FIXME: Does it become easier if we add ModuleConstantDef to this dict? class OurVisitor: """ @@ -95,14 +76,14 @@ class OurVisitor: module.constant_defs[res.name] = res - if isinstance(res, TypeStruct): - if res.name in module.structs: + if isinstance(res, StructDefinition): + if res.struct_type3.name in module.struct_definitions: raise StaticError( - f'{res.name} already defined on line {module.structs[res.name].lineno}' + f'{res.struct_type3.name} already defined on line {module.struct_definitions[res.struct_type3.name].lineno}' ) - module.structs[res.name] = res - constructor = StructConstructor(res) + module.struct_definitions[res.struct_type3.name] = res + constructor = StructConstructor(res.struct_type3) module.functions[constructor.name] = constructor if isinstance(res, Function): @@ -120,7 +101,7 @@ class OurVisitor: return module - def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, TypeStruct, ModuleConstantDef]: + def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, StructDefinition, ModuleConstantDef]: if isinstance(node, ast.FunctionDef): return self.pre_visit_Module_FunctionDef(module, node) @@ -138,12 +119,9 @@ class OurVisitor: _not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs') for arg in node.args.args: - if not arg.annotation: - _raise_static_error(node, 'Type is required') - - function.posonlyargs.append(( + function.posonlyargs.append(FunctionParam( arg.arg, - self.visit_type(module, arg.annotation), + self.visit_type(module, arg.annotation) if arg.annotation else None, )) _not_implemented(not node.args.vararg, 'FunctionDef.args.vararg') @@ -166,21 +144,23 @@ class OurVisitor: else: function.imported = True - if node.returns: - function.returns = self.visit_type(module, node.returns) + if node.returns is not None: # Note: `-> None` would be a ast.Constant + function.returns_type3 = self.visit_type(module, node.returns) + else: + # FIXME: Mostly works already, needs to fix Function.returns_type3 and have it updated + raise NotImplementedError('Function without an explicit return type') _not_implemented(not node.type_comment, 'FunctionDef.type_comment') return function - def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> TypeStruct: - struct = TypeStruct(node.name, node.lineno) + def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> StructDefinition: _not_implemented(not node.bases, 'ClassDef.bases') _not_implemented(not node.keywords, 'ClassDef.keywords') _not_implemented(not node.decorator_list, 'ClassDef.decorator_list') - offset = 0 + members: Dict[str, type3types.Type3] = {} for stmt in node.body: if not isinstance(stmt, ast.AnnAssign): @@ -195,47 +175,36 @@ class OurVisitor: if stmt.simple != 1: raise NotImplementedError('Class with non-simple arguments') - member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) + if stmt.target.id in members: + _raise_static_error(stmt, 'Struct members must have unique names') - struct.members.append(member) - offset += member.type.alloc_size() + members[stmt.target.id] = self.visit_type(module, stmt.annotation) - return struct + return StructDefinition(type3types.StructType3(node.name, members), node.lineno) def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef: if not isinstance(node.target, ast.Name): - _raise_static_error(node, 'Must be name') + _raise_static_error(node.target, 'Must be name') if not isinstance(node.target.ctx, ast.Store): - _raise_static_error(node, 'Must be load context') + _raise_static_error(node.target, 'Must be store context') - exp_type = self.visit_type(module, node.annotation) - - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, ast.Constant): - _raise_static_error(node, 'Must be constant') - - constant = ModuleConstantDef( + if isinstance(node.value, ast.Constant): + type3 = self.visit_type(module, node.annotation) + return ModuleConstantDef( node.target.id, node.lineno, - exp_type, - self.visit_Module_Constant(module, exp_type, node.value), + type3, + self.visit_Module_Constant(module, node.value), None, ) - return constant - - if isinstance(exp_type, TypeTuple): - if not isinstance(node.value, ast.Tuple): - _raise_static_error(node, 'Must be tuple') - - if len(exp_type.members) != len(node.value.elts): - _raise_static_error(node, 'Invalid number of tuple values') + if isinstance(node.value, ast.Tuple): tuple_data = [ - self.visit_Module_Constant(module, mem.type, arg_node) - for arg_node, mem in zip(node.value.elts, exp_type.members) + self.visit_Module_Constant(module, arg_node) + for arg_node in node.value.elts if isinstance(arg_node, ast.Constant) ] - if len(exp_type.members) != len(tuple_data): + if len(node.value.elts) != len(tuple_data): _raise_static_error(node, 'Tuple arguments must be constants') # Allocate the data @@ -246,36 +215,47 @@ class OurVisitor: return ModuleConstantDef( node.target.id, node.lineno, - exp_type, - ConstantTuple(exp_type, tuple_data), + self.visit_type(module, node.annotation), + ConstantTuple(tuple_data), data_block, ) - if isinstance(exp_type, TypeStaticArray): - if not isinstance(node.value, ast.Tuple): - _raise_static_error(node, 'Must be static array') + if isinstance(node.value, ast.Call): + # Struct constant + # Stored in memory like a tuple, so much of the code is the same - if len(exp_type.members) != len(node.value.elts): - _raise_static_error(node, 'Invalid number of static array values') + if not isinstance(node.value.func, ast.Name): + _raise_static_error(node.value.func, 'Must be name') + if not isinstance(node.value.func.ctx, ast.Load): + _raise_static_error(node.value.func, 'Must be load context') - static_array_data = [ - self.visit_Module_Constant(module, exp_type.member_type, arg_node) - for arg_node in node.value.elts + if not node.value.func.id in module.struct_definitions: + _raise_static_error(node.value.func, 'Undefined struct') + + if node.value.keywords: + _raise_static_error(node.value.func, 'Cannot use keywords') + + if not isinstance(node.annotation, ast.Name): + _raise_static_error(node.annotation, 'Must be name') + + struct_data = [ + self.visit_Module_Constant(module, arg_node) + for arg_node in node.value.args if isinstance(arg_node, ast.Constant) ] - if len(exp_type.members) != len(static_array_data): - _raise_static_error(node, 'Static array arguments must be constants') + if len(node.value.args) != len(struct_data): + _raise_static_error(node, 'Struct arguments must be constants') # Allocate the data - data_block = ModuleDataBlock(static_array_data) + data_block = ModuleDataBlock(struct_data) module.data.blocks.append(data_block) # Then return the constant as a pointer return ModuleConstantDef( node.target.id, node.lineno, - exp_type, - ConstantStaticArray(exp_type, static_array_data), + self.visit_type(module, node.annotation), + ConstantStruct(node.value.func.id, struct_data), data_block, ) @@ -297,7 +277,10 @@ class OurVisitor: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: function = module.functions[node.name] - our_locals = dict(function.posonlyargs) + our_locals: OurLocals = { + x.name: x + for x in function.posonlyargs + } for stmt in node.body: function.statements.append( @@ -311,12 +294,12 @@ class OurVisitor: _raise_static_error(node, 'Return must have an argument') return StatementReturn( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value) ) if isinstance(node, ast.If): result = StatementIf( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test) ) for stmt in node.body: @@ -336,7 +319,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as stmt in FunctionDef') - def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: + def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression: if isinstance(node, ast.BinOp): if isinstance(node.op, ast.Add): operator = '+' @@ -344,6 +327,8 @@ class OurVisitor: operator = '-' elif isinstance(node.op, ast.Mult): operator = '*' + elif isinstance(node.op, ast.Div): + operator = '/' elif isinstance(node.op, ast.LShift): operator = '<<' elif isinstance(node.op, ast.RShift): @@ -357,14 +342,10 @@ class OurVisitor: else: raise NotImplementedError(f'Operator {node.op}') - # Assume the type doesn't change when descending into a binary operator - # e.g. you can do `"hello" * 3` with the code below (yet) - return BinaryOp( - exp_type, operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right), ) if isinstance(node, ast.UnaryOp): @@ -376,9 +357,8 @@ class OurVisitor: raise NotImplementedError(f'Operator {node.op}') return UnaryOp( - exp_type, operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.operand), ) if isinstance(node, ast.Compare): @@ -394,32 +374,28 @@ class OurVisitor: else: raise NotImplementedError(f'Operator {node.ops}') - # Assume the type doesn't change when descending into a binary operator - # e.g. you can do `"hello" * 3` with the code below (yet) - return BinaryOp( - exp_type, operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]), ) if isinstance(node, ast.Call): - return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) + return self.visit_Module_FunctionDef_Call(module, function, our_locals, node) if isinstance(node, ast.Constant): return self.visit_Module_Constant( - module, exp_type, node, + module, node, ) if isinstance(node, ast.Attribute): return self.visit_Module_FunctionDef_Attribute( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Subscript): return self.visit_Module_FunctionDef_Subscript( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Name): @@ -427,45 +403,30 @@ class OurVisitor: _raise_static_error(node, 'Must be load context') if node.id in our_locals: - act_type = our_locals[node.id] - if exp_type != act_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') - - return VariableReference(act_type, node.id) + param = our_locals[node.id] + return VariableReference(param) if node.id in module.constant_defs: cdef = module.constant_defs[node.id] - if exp_type != cdef.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}') - - return ModuleConstantReference(exp_type, cdef) + return VariableReference(cdef) _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') + arguments = [ + self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_node) + for arg_node in node.elts + if isinstance(arg_node, ast.Constant) + ] - if isinstance(exp_type, TypeTuple): - if len(exp_type.members) != len(node.elts): - _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') + if len(arguments) != len(node.elts): + raise NotImplementedError('Non-constant tuple members') - tuple_constructor = TupleConstructor(exp_type) - - func = module.functions[tuple_constructor.name] - - result = FunctionCall(func) - result.arguments = [ - self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) - for arg_node, mem in zip(node.elts, exp_type.members) - ] - return result - - _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') + return TupleInstantiation(arguments) raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -474,59 +435,42 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.func.id in module.structs: - struct = module.structs[node.func.id] - struct_constructor = StructConstructor(struct) + if node.func.id in module.struct_definitions: + struct_definition = module.struct_definitions[node.func.id] + struct_constructor = StructConstructor(struct_definition.struct_type3) + + # FIXME: Defer struct de-allocation func = module.functions[struct_constructor.name] - elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: - if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): - _raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}') - + elif node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'sqrt', - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'u32': - if not isinstance(exp_type, TypeUInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') - # FIXME: This is a stub, proper casting is todo - return UnaryOp( - exp_type, 'cast', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'len': - if not isinstance(exp_type, TypeInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'len', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'foldl': - # TODO: This should a much more generic function! - # For development purposes, we're assuming you're doing a foldl(Callable[[u8, u8], u8], u8, bytes) - # In the future, we should probably infer the type of the second argument, - # and use it as expected types for the other u8s and the Iterable[u8] (i.e. bytes) - if 3 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given') - # TODO: This is not generic + # TODO: This is not generic, you cannot return a function subnode = node.args[0] if not isinstance(subnode, ast.Name): raise NotImplementedError(f'Calling methods that are not a name {subnode}') @@ -538,21 +482,11 @@ class OurVisitor: if 2 != len(func.posonlyargs): _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') - if exp_type.__class__ != func.returns.__class__: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - - if func.returns.__class__ != func.posonlyargs[0][1].__class__: - _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') - - if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__: - _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') - return Fold( - exp_type, Fold.Direction.LEFT, func, - self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), ) else: if node.func.id not in module.functions: @@ -560,206 +494,76 @@ class OurVisitor: func = module.functions[node.func.id] - if func.returns != exp_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - if len(func.posonlyargs) != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') result = FunctionCall(func) result.arguments.extend( - self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) - for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) + self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) + for arg_expr, param in zip(node.args, func.posonlyargs) ) return result - def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: - del module - del function - + def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must reference a name') if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') + varref = self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value) + if not isinstance(varref, VariableReference): + _raise_static_error(node.value, 'Must refer to variable') - node_typ = our_locals[node.value.id] - if not isinstance(node_typ, TypeStruct): - _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') - - member = node_typ.get_member(node.attr) - if member is None: - _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - - if exp_type != member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') + if not isinstance(varref.variable.type3, type3types.StructType3): + _raise_static_error(node.value, 'Must refer to struct') return AccessStructMember( - VariableReference(node_typ, node.value.id), - member, + varref, + varref.variable.type3, + node.attr, ) - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must reference a name') - if not isinstance(node.slice, ast.Index): + if isinstance(node.slice, ast.Slice): _raise_static_error(node, 'Must subscript using an index') if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - varref: Union[ModuleConstantReference, VariableReference] + varref: VariableReference if node.value.id in our_locals: - node_typ = our_locals[node.value.id] - varref = VariableReference(node_typ, node.value.id) + param = our_locals[node.value.id] + varref = VariableReference(param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] - node_typ = constant_def.type - varref = ModuleConstantReference(node_typ, constant_def) + varref = VariableReference(constant_def) else: _raise_static_error(node, f'Undefined variable {node.value.id}') slice_expr = self.visit_Module_FunctionDef_expr( - module, function, our_locals, module.types['u32'], node.slice.value, + module, function, our_locals, node.slice, ) - if isinstance(node_typ, TypeBytes): - t_u8 = module.types['u8'] - if exp_type != t_u8: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}') + return Subscript(varref, slice_expr) - if isinstance(varref, ModuleConstantReference): - raise NotImplementedError(f'{node} from module constant') - - return AccessBytesIndex( - t_u8, - varref, - slice_expr, - ) - - if isinstance(node_typ, TypeTuple): - if not isinstance(slice_expr, ConstantUInt32): - _raise_static_error(node, 'Must subscript using a constant index') - - idx = slice_expr.value - - if len(node_typ.members) <= idx: - _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') - - tuple_member = node_typ.members[idx] - if exp_type != tuple_member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}') - - if isinstance(varref, ModuleConstantReference): - raise NotImplementedError(f'{node} from module constant') - - return AccessTupleMember( - varref, - tuple_member, - ) - - if isinstance(node_typ, TypeStaticArray): - if exp_type != node_typ.member_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') - - if not isinstance(slice_expr, ConstantInt32): - return AccessStaticArrayMember( - varref, - node_typ, - slice_expr, - ) - - idx = slice_expr.value - - if len(node_typ.members) <= idx: - _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') - - static_array_member = node_typ.members[idx] - - return AccessStaticArrayMember( - varref, - node_typ, - static_array_member, - ) - - _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - - def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: + def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: del module _not_implemented(node.kind is None, 'Constant.kind') - if isinstance(exp_type, TypeUInt8): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') + if isinstance(node.value, (int, float, )): + return ConstantPrimitive(node.value) - if node.value < 0 or node.value > 255: - _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') + raise NotImplementedError(f'{node.value} as constant') - return ConstantUInt8(exp_type, node.value) - - if isinstance(exp_type, TypeUInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 4294967295: - _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt32(exp_type, node.value) - - if isinstance(exp_type, TypeUInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 18446744073709551615: - _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt64(exp_type, node.value) - - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -2147483648 or node.value > 2147483647: - _raise_static_error(node, 'Integer value out of range') - - return ConstantInt32(exp_type, node.value) - - if isinstance(exp_type, TypeInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -9223372036854775808 or node.value > 9223372036854775807: - _raise_static_error(node, 'Integer value out of range') - - return ConstantInt64(exp_type, node.value) - - if isinstance(exp_type, TypeFloat32): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat32(exp_type, node.value) - - if isinstance(exp_type, TypeFloat64): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat64(exp_type, node.value) - - raise NotImplementedError(f'{node} as const for type {exp_type}') - - def visit_type(self, module: Module, node: ast.expr) -> TypeBase: + def visit_type(self, module: Module, node: ast.expr) -> type3types.Type3: if isinstance(node, ast.Constant): if node.value is None: - return module.types['None'] + return type3types.none _raise_static_error(node, f'Unrecognized type {node.value}') @@ -767,70 +571,42 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.id in module.types: - return module.types[node.id] + if node.id in type3types.LOOKUP_TABLE: + return type3types.LOOKUP_TABLE[node.id] - if node.id in module.structs: - return module.structs[node.id] + if node.id in module.struct_definitions: + return module.struct_definitions[node.id].struct_type3 _raise_static_error(node, f'Unrecognized type {node.id}') if isinstance(node, ast.Subscript): if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must be name') - if not isinstance(node.slice, ast.Index): + if isinstance(node.slice, ast.Slice): _raise_static_error(node, 'Must subscript using an index') - if not isinstance(node.slice.value, ast.Constant): + if not isinstance(node.slice, ast.Constant): _raise_static_error(node, 'Must subscript using a constant index') - if not isinstance(node.slice.value.value, int): + if not isinstance(node.slice.value, int): _raise_static_error(node, 'Must subscript using a constant integer index') if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.value.id in module.types: - member_type = module.types[node.value.id] - else: + if node.value.id not in type3types.LOOKUP_TABLE: # FIXME: Tuple of tuples? _raise_static_error(node, f'Unrecognized type {node.value.id}') - type_static_array = TypeStaticArray(member_type) - - offset = 0 - - for idx in range(node.slice.value.value): - static_array_member = TypeStaticArrayMember(idx, offset) - - type_static_array.members.append(static_array_member) - offset += member_type.alloc_size() - - key = f'{node.value.id}[{node.slice.value.value}]' - - if key not in module.types: - module.types[key] = type_static_array - - return module.types[key] + return type3types.AppliedType3( + type3types.static_array, + [self.visit_type(module, node.value), type3types.IntType3(node.slice.value)], + ) if isinstance(node, ast.Tuple): if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - type_tuple = TypeTuple() - - offset = 0 - - for idx, elt in enumerate(node.elts): - tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) - - type_tuple.members.append(tuple_member) - offset += tuple_member.type.alloc_size() - - key = type_tuple.render_internal_name() - - if key not in module.types: - module.types[key] = type_tuple - constructor = TupleConstructor(type_tuple) - module.functions[constructor.name] = constructor - - return module.types[key] + return type3types.AppliedType3( + type3types.tuple, + (self.visit_type(module, elt) for elt in node.elts) + ) raise NotImplementedError(f'{node} as type') diff --git a/phasm/stdlib/alloc.py b/phasm/stdlib/alloc.py index 2761bfb..8c5742d 100644 --- a/phasm/stdlib/alloc.py +++ b/phasm/stdlib/alloc.py @@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32: g.i32.const(0) g.return_() - del alloc_size # TODO + del alloc_size # TODO: Actual implement using a previously freed block g.unreachable() return i32('return') # To satisfy mypy diff --git a/phasm/type3/__init__.py b/phasm/type3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py new file mode 100644 index 0000000..31b8f20 --- /dev/null +++ b/phasm/type3/constraints.py @@ -0,0 +1,549 @@ +""" +This module contains possible constraints generated based on the AST + +These need to be resolved before the program can be compiled. +""" +from typing import Dict, Optional, List, Tuple, Union + +from .. import ourlang + +from . import types + +class Error: + """ + An error returned by the check functions for a contraint + + This means the programmer has to make some kind of chance to the + typing of their program before the compiler can do its thing. + """ + def __init__(self, msg: str) -> None: + self.msg = msg + + def __repr__(self) -> str: + return f'Error({repr(self.msg)})' + +class RequireTypeSubstitutes: + """ + Returned by the check function for a contraint if they do not have all + their types substituted yet. + + Hopefully, another constraint will give the right information about the + typing of the program, so this constraint can be updated. + """ + +SubstitutionMap = Dict[types.PlaceholderForType, types.Type3] + +NewConstraintList = List['ConstraintBase'] + +CheckResult = Union[None, SubstitutionMap, Error, NewConstraintList, RequireTypeSubstitutes] + +HumanReadableRet = Tuple[str, Dict[str, Union[str, ourlang.Expression, types.Type3, types.PlaceholderForType]]] + +class Context: + """ + Context for constraints + """ + + __slots__ = () + +class ConstraintBase: + """ + Base class for constraints + """ + __slots__ = ('comment', ) + + comment: Optional[str] + """ + A comment to help the programmer with debugging the types in their program + """ + + def __init__(self, comment: Optional[str] = None) -> None: + self.comment = comment + + def check(self) -> CheckResult: + """ + Checks if the constraint hold + + This function can return an error, if the constraint does not hold, + which indicates an error in the typing of the input program. + + This function can return RequireTypeSubstitutes(), if we cannot deduce + all the types yet. + + This function can return a SubstitutionMap, if during the evaluation + of the contraint we discovered new types. In this case, the constraint + is expected to hold. + + This function can return None, if the constraint holds, but no new + information was deduced from evaluating this constraint. + """ + raise NotImplementedError(self.__class__, self.check) + + def human_readable(self) -> HumanReadableRet: + """ + Returns a more human readable form of this constraint + """ + return repr(self), {} + +class SameTypeConstraint(ConstraintBase): + """ + Verifies that a number of types all are the same type + """ + __slots__ = ('type_list', ) + + type_list: List[types.Type3OrPlaceholder] + + def __init__(self, *type_list: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None: + super().__init__(comment=comment) + + assert len(type_list) > 1 + self.type_list = [*type_list] + + def check(self) -> CheckResult: + known_types: List[types.Type3] = [] + placeholders = [] + do_applied_placeholder_check: bool = False + for typ in self.type_list: + if isinstance(typ, types.IntType3): + known_types.append(typ) + continue + + if isinstance(typ, (types.PrimitiveType3, types.StructType3, )): + known_types.append(typ) + continue + + if isinstance(typ, types.AppliedType3): + known_types.append(typ) + do_applied_placeholder_check = True + continue + + if isinstance(typ, types.PlaceholderForType): + if typ.resolve_as is not None: + known_types.append(typ.resolve_as) + else: + placeholders.append(typ) + continue + + raise NotImplementedError(typ) + + if not known_types: + return RequireTypeSubstitutes() + + new_constraint_list: List[ConstraintBase] = [] + + first_type = known_types[0] + for typ in known_types[1:]: + if isinstance(first_type, types.AppliedType3) and isinstance(typ, types.AppliedType3): + if len(first_type.args) != len(typ.args): + return Error('Mismatch between applied types argument count') + + if first_type.base != typ.base: + return Error('Mismatch between applied types base') + + for first_type_arg, typ_arg in zip(first_type.args, typ.args): + new_constraint_list.append(SameTypeConstraint( + first_type_arg, typ_arg + )) + + continue + + if typ != first_type: + return Error(f'{typ:s} must be {first_type:s} instead') + + if new_constraint_list: + # If this happens, make CheckResult a class that can have both + assert not placeholders, 'Cannot (yet) return both new placeholders and new constraints' + + return new_constraint_list + + if not placeholders: + return None + + for typ in placeholders: + typ.resolve_as = first_type + + return { + typ: first_type + for typ in placeholders + } + + def human_readable(self) -> HumanReadableRet: + return ( + ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))), + { + 't' + str(idx): typ + for idx, typ in enumerate(self.type_list) + }, + ) + + def __repr__(self) -> str: + args = ', '.join(repr(x) for x in self.type_list) + + return f'SameTypeConstraint({args}, comment={repr(self.comment)})' + +class IntegerCompareConstraint(ConstraintBase): + """ + Verifies that the given IntType3 are in order (<=) + """ + __slots__ = ('int_type3_list', ) + + int_type3_list: List[types.IntType3] + + def __init__(self, *int_type3: types.IntType3, comment: Optional[str] = None) -> None: + super().__init__(comment=comment) + + assert len(int_type3) > 1 + self.int_type3_list = [*int_type3] + + def check(self) -> CheckResult: + val_list = [x.value for x in self.int_type3_list] + + prev_val = val_list.pop(0) + for next_val in val_list: + if prev_val > next_val: + return Error(f'{prev_val} must be less or equal than {next_val}') + + prev_val = next_val + + return None + + def human_readable(self) -> HumanReadableRet: + return ( + ' <= '.join('{t' + str(idx) + '}' for idx in range(len(self.int_type3_list))), + { + 't' + str(idx): typ + for idx, typ in enumerate(self.int_type3_list) + }, + ) + + def __repr__(self) -> str: + args = ', '.join(repr(x) for x in self.int_type3_list) + + return f'IntegerCompareConstraint({args}, comment={repr(self.comment)})' + +class CastableConstraint(ConstraintBase): + """ + A type can be cast to another type + """ + __slots__ = ('from_type3', 'to_type3', ) + + from_type3: types.Type3OrPlaceholder + to_type3: types.Type3OrPlaceholder + + def __init__(self, from_type3: types.Type3OrPlaceholder, to_type3: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None: + super().__init__(comment=comment) + + self.from_type3 = from_type3 + self.to_type3 = to_type3 + + def check(self) -> CheckResult: + ftyp = self.from_type3 + if isinstance(ftyp, types.PlaceholderForType) and ftyp.resolve_as is not None: + ftyp = ftyp.resolve_as + + ttyp = self.to_type3 + if isinstance(ttyp, types.PlaceholderForType) and ttyp.resolve_as is not None: + ttyp = ttyp.resolve_as + + if isinstance(ftyp, types.PlaceholderForType) or isinstance(ttyp, types.PlaceholderForType): + return RequireTypeSubstitutes() + + if ftyp is types.u8 and ttyp is types.u32: + return None + + return Error(f'Cannot cast {ftyp.name} to {ttyp.name}') + + def human_readable(self) -> HumanReadableRet: + return ( + '{to_type3}({from_type3})', + { + 'to_type3': self.to_type3, + 'from_type3': self.from_type3, + }, + ) + + def __repr__(self) -> str: + return f'CastableConstraint({repr(self.from_type3)}, {repr(self.to_type3)}, comment={repr(self.comment)})' + +class MustImplementTypeClassConstraint(ConstraintBase): + """ + A type must implement a given type class + """ + __slots__ = ('type_class3', 'type3', ) + + type_class3: str + type3: types.Type3OrPlaceholder + + DATA = { + 'u8': {'BitWiseOperation', 'BasicMathOperation', 'EqualComparison', 'StrictPartialOrder'}, + 'u32': {'BitWiseOperation', 'BasicMathOperation', 'EqualComparison', 'StrictPartialOrder'}, + 'u64': {'BitWiseOperation', 'BasicMathOperation', 'EqualComparison', 'StrictPartialOrder'}, + 'i32': {'BasicMathOperation', 'EqualComparison', 'StrictPartialOrder'}, + 'i64': {'BasicMathOperation', 'EqualComparison', 'StrictPartialOrder'}, + 'bytes': {'Foldable', 'Sized'}, + 'f32': {'BasicMathOperation', 'FloatingPoint'}, + 'f64': {'BasicMathOperation', 'FloatingPoint'}, + } + + def __init__(self, type_class3: str, type3: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None: + super().__init__(comment=comment) + + self.type_class3 = type_class3 + self.type3 = type3 + + def check(self) -> CheckResult: + typ = self.type3 + if isinstance(typ, types.PlaceholderForType) and typ.resolve_as is not None: + typ = typ.resolve_as + + if isinstance(typ, types.PlaceholderForType): + return RequireTypeSubstitutes() + + if self.type_class3 in self.__class__.DATA.get(typ.name, set()): + return None + + return Error(f'{typ.name} does not implement the {self.type_class3} type class') + + def human_readable(self) -> HumanReadableRet: + return ( + '{type3} derives {type_class3}', + { + 'type_class3': self.type_class3, + 'type3': self.type3, + }, + ) + + def __repr__(self) -> str: + return f'MustImplementTypeClassConstraint({repr(self.type_class3)}, {repr(self.type3)}, comment={repr(self.comment)})' + +class LiteralFitsConstraint(ConstraintBase): + """ + A literal value fits a given type + """ + __slots__ = ('type3', 'literal', ) + + type3: types.Type3OrPlaceholder + literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct] + + def __init__( + self, + type3: types.Type3OrPlaceholder, + literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct], + comment: Optional[str] = None, + ) -> None: + super().__init__(comment=comment) + + self.type3 = type3 + self.literal = literal + + def check(self) -> CheckResult: + int_table: Dict[str, Tuple[int, bool]] = { + 'u8': (1, False), + 'u32': (4, False), + 'u64': (8, False), + 'i8': (1, True), + 'i32': (4, True), + 'i64': (8, True), + } + + float_table: Dict[str, None] = { + 'f32': None, + 'f64': None, + } + + if isinstance(self.type3, types.PlaceholderForType): + if self.type3.resolve_as is None: + return RequireTypeSubstitutes() + + self.type3 = self.type3.resolve_as + + if self.type3.name in int_table: + bts, sgn = int_table[self.type3.name] + + if isinstance(self.literal.value, int): + try: + self.literal.value.to_bytes(bts, 'big', signed=sgn) + except OverflowError: + return Error(f'Must fit in {bts} byte(s)') # FIXME: Add line information + + return None + + return Error('Must be integer') # FIXME: Add line information + + if self.type3.name in float_table: + _ = float_table[self.type3.name] + + if isinstance(self.literal.value, float): + # FIXME: Bit check + + return None + + return Error('Must be real') # FIXME: Add line information + + res: NewConstraintList + + if isinstance(self.type3, types.AppliedType3): + if self.type3.base == types.tuple: + if not isinstance(self.literal, ourlang.ConstantTuple): + return Error('Must be tuple') + + if len(self.type3.args) != len(self.literal.value): + return Error('Tuple element count mismatch') + + res = [] + + res.extend( + LiteralFitsConstraint(x, y) + for x, y in zip(self.type3.args, self.literal.value) + ) + res.extend( + SameTypeConstraint(x, y.type3) + for x, y in zip(self.type3.args, self.literal.value) + ) + + return res + + if self.type3.base == types.static_array: + if not isinstance(self.literal, ourlang.ConstantTuple): + return Error('Must be tuple') + + assert 2 == len(self.type3.args) + assert isinstance(self.type3.args[1], types.IntType3) + + if self.type3.args[1].value != len(self.literal.value): + return Error('Member count mismatch') + + res = [] + + res.extend( + LiteralFitsConstraint(self.type3.args[0], y) + for y in self.literal.value + ) + res.extend( + SameTypeConstraint(self.type3.args[0], y.type3) + for y in self.literal.value + ) + + return res + + if isinstance(self.type3, types.StructType3): + if not isinstance(self.literal, ourlang.ConstantStruct): + return Error('Must be struct') + + if self.literal.struct_name != self.type3.name: + return Error('Struct mismatch') + + + if len(self.type3.members) != len(self.literal.value): + return Error('Struct element count mismatch') + + res = [] + + res.extend( + LiteralFitsConstraint(x, y) + for x, y in zip(self.type3.members.values(), self.literal.value) + ) + res.extend( + SameTypeConstraint(x_t, y.type3, comment=f'{self.literal.struct_name}.{x_n}') + for (x_n, x_t, ), y in zip(self.type3.members.items(), self.literal.value) + ) + + return res + + raise NotImplementedError(self.type3, self.literal) + + def human_readable(self) -> HumanReadableRet: + return ( + '{literal} : {type3}', + { + 'literal': self.literal, + 'type3': self.type3, + }, + ) + + def __repr__(self) -> str: + return f'LiteralFitsConstraint({repr(self.type3)}, {repr(self.literal)}, comment={repr(self.comment)})' + +class CanBeSubscriptedConstraint(ConstraintBase): + """ + A value that is subscipted, i.e. a[0] (tuple) or a[b] (static array) + """ + __slots__ = ('ret_type3', 'type3', 'index', 'index_type3', ) + + ret_type3: types.Type3OrPlaceholder + type3: types.Type3OrPlaceholder + index: ourlang.Expression + index_type3: types.Type3OrPlaceholder + + def __init__(self, ret_type3: types.Type3OrPlaceholder, type3: types.Type3OrPlaceholder, index: ourlang.Expression, comment: Optional[str] = None) -> None: + super().__init__(comment=comment) + + self.ret_type3 = ret_type3 + self.type3 = type3 + self.index = index + self.index_type3 = index.type3 + + def check(self) -> CheckResult: + if isinstance(self.type3, types.PlaceholderForType): + if self.type3.resolve_as is None: + return RequireTypeSubstitutes() + + self.type3 = self.type3.resolve_as + + if isinstance(self.type3, types.AppliedType3): + if self.type3.base == types.static_array: + result: List[ConstraintBase] = [ + SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), + SameTypeConstraint(self.type3.args[0], self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), + ] + + if isinstance(self.index, ourlang.ConstantPrimitive): + assert isinstance(self.index.value, int) + assert isinstance(self.type3.args[1], types.IntType3) + + result.append( + IntegerCompareConstraint( + types.IntType3(0), types.IntType3(self.index.value), types.IntType3(self.type3.args[1].value - 1), + comment='Subscript static array must fit the size of the array' + ) + ) + + return result + + if self.type3.base == types.tuple: + if not isinstance(self.index, ourlang.ConstantPrimitive): + return Error('Must index with literal') + + if not isinstance(self.index.value, int): + return Error('Must index with integer literal') + + if self.index.value < 0 or len(self.type3.args) <= self.index.value: + return Error('Tuple index out of range') + + return [ + SameTypeConstraint(types.u32, self.index_type3, comment=f'Tuple subscript index {self.index.value}'), + SameTypeConstraint(self.type3.args[self.index.value], self.ret_type3, comment=f'Tuple subscript index {self.index.value}'), + ] + + if self.type3 is types.bytes: + return [ + SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: bytes -> u32 -> u8'), + SameTypeConstraint(types.u8, self.ret_type3, comment='([]) :: bytes -> u32 -> u8'), + ] + + if self.type3.name in types.LOOKUP_TABLE: + return Error(f'{self.type3.name} cannot be subscripted') + + raise NotImplementedError(self.type3) + + def human_readable(self) -> HumanReadableRet: + return ( + '{type3}[{index}]', + { + 'type3': self.type3, + 'index': self.index, + }, + ) + + def __repr__(self) -> str: + return f'CanBeSubscriptedConstraint({repr(self.type3)}, {repr(self.index)}, comment={repr(self.comment)})' diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py new file mode 100644 index 0000000..61e7ec9 --- /dev/null +++ b/phasm/type3/constraintsgenerator.py @@ -0,0 +1,219 @@ +""" +This module generates the typing constraints for Phasm. + +The constraints solver can then try to resolve all constraints. +""" +from typing import Generator, List + +from .. import ourlang + +from .constraints import ( + Context, + + ConstraintBase, + CastableConstraint, CanBeSubscriptedConstraint, + LiteralFitsConstraint, MustImplementTypeClassConstraint, SameTypeConstraint, +) + +from . import types as type3types + +ConstraintGenerator = Generator[ConstraintBase, None, None] + +def phasm_type3_generate_constraints(inp: ourlang.Module) -> List[ConstraintBase]: + ctx = Context() + + return [*module(ctx, inp)] + +def constant(ctx: Context, inp: ourlang.Constant) -> ConstraintGenerator: + if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct)): + yield LiteralFitsConstraint(inp.type3, inp) + return + + raise NotImplementedError(constant, inp) + +def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: + if isinstance(inp, ourlang.Constant): + yield from constant(ctx, inp) + return + + if isinstance(inp, ourlang.VariableReference): + yield SameTypeConstraint(inp.variable.type3, inp.type3, + comment=f'typeOf("{inp.variable.name}") == typeOf({inp.variable.name})') + return + + if isinstance(inp, ourlang.UnaryOp): + if 'len' == inp.operator: + yield from expression(ctx, inp.right) + yield MustImplementTypeClassConstraint('Sized', inp.right.type3) + yield SameTypeConstraint(type3types.u32, inp.type3, comment='len :: Sized a => a -> u32') + return + + if 'sqrt' == inp.operator: + yield from expression(ctx, inp.right) + yield MustImplementTypeClassConstraint('FloatingPoint', inp.right.type3) + yield SameTypeConstraint(inp.right.type3, inp.type3, comment='sqrt :: FloatingPoint a => a -> a') + return + + if 'cast' == inp.operator: + yield from expression(ctx, inp.right) + yield CastableConstraint(inp.right.type3, inp.type3) + return + + raise NotImplementedError(expression, inp, inp.operator) + + if isinstance(inp, ourlang.BinaryOp): + if inp.operator in ('|', '&', '^', ): + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) + + yield MustImplementTypeClassConstraint('BitWiseOperation', inp.left.type3) + yield SameTypeConstraint(inp.left.type3, inp.right.type3, inp.type3, + comment=f'({inp.operator}) :: a -> a -> a') + return + + if inp.operator in ('>>', '<<', ): + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) + + yield MustImplementTypeClassConstraint('BitWiseOperation', inp.left.type3) + yield SameTypeConstraint(inp.left.type3, inp.right.type3, inp.type3, + comment=f'({inp.operator}) :: a -> a -> a') + return + + if inp.operator in ('+', '-', '*', '/', ): + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) + + yield MustImplementTypeClassConstraint('BasicMathOperation', inp.left.type3) + yield SameTypeConstraint(inp.left.type3, inp.right.type3, inp.type3, + comment=f'({inp.operator}) :: a -> a -> a') + return + + if inp.operator == '==': + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) + + yield MustImplementTypeClassConstraint('EqualComparison', inp.left.type3) + yield SameTypeConstraint(inp.left.type3, inp.right.type3, + comment=f'({inp.operator}) :: a -> a -> bool') + yield SameTypeConstraint(inp.type3, type3types.bool_, + comment=f'({inp.operator}) :: a -> a -> bool') + return + + if inp.operator in ('<', '>'): + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) + + yield MustImplementTypeClassConstraint('StrictPartialOrder', inp.left.type3) + yield SameTypeConstraint(inp.left.type3, inp.right.type3, + comment=f'({inp.operator}) :: a -> a -> bool') + yield SameTypeConstraint(inp.type3, type3types.bool_, + comment=f'({inp.operator}) :: a -> a -> bool') + return + + raise NotImplementedError(expression, inp) + + if isinstance(inp, ourlang.FunctionCall): + yield SameTypeConstraint(inp.function.returns_type3, inp.type3, + comment=f'The type of a function call to {inp.function.name} is the same as the type that the function returns') + + assert len(inp.arguments) == len(inp.function.posonlyargs) # FIXME: Make this a Constraint + + for fun_arg, call_arg in zip(inp.function.posonlyargs, inp.arguments): + yield from expression(ctx, call_arg) + yield SameTypeConstraint(fun_arg.type3, call_arg.type3, + comment=f'The type of the value passed to argument {fun_arg.name} of function {inp.function.name} should match the type of that argument') + + return + + if isinstance(inp, ourlang.TupleInstantiation): + r_type = [] + for arg in inp.elements: + yield from expression(ctx, arg) + r_type.append(arg.type3) + + yield SameTypeConstraint( + inp.type3, + type3types.AppliedType3(type3types.tuple, r_type), + comment=f'The type of a tuple is a combination of its members' + ) + + return + + if isinstance(inp, ourlang.Subscript): + yield from expression(ctx, inp.varref) + yield from expression(ctx, inp.index) + + yield CanBeSubscriptedConstraint(inp.type3, inp.varref.type3, inp.index) + return + + if isinstance(inp, ourlang.AccessStructMember): + yield from expression(ctx, inp.varref) + yield SameTypeConstraint(inp.struct_type3.members[inp.member], inp.type3, + comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') + return + + if isinstance(inp, ourlang.Fold): + yield from expression(ctx, inp.base) + yield from expression(ctx, inp.iter) + + yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, inp.base.type3, inp.type3, + comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b') + yield MustImplementTypeClassConstraint('Foldable', inp.iter.type3) + + return + + raise NotImplementedError(expression, inp) + +def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: + yield from expression(ctx, inp.value) + + yield SameTypeConstraint(fun.returns_type3, inp.value.type3, + comment=f'The type of the value returned from function {fun.name} should match its return type') + +def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator: + yield from expression(ctx, inp.test) + + yield SameTypeConstraint(inp.test.type3, type3types.bool_, + comment=f'Must pass a boolean expression to if') + + for stmt in inp.statements: + yield from statement(ctx, fun, stmt) + + for stmt in inp.else_statements: + yield from statement(ctx, fun, stmt) + +def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> ConstraintGenerator: + if isinstance(inp, ourlang.StatementReturn): + yield from statement_return(ctx, fun, inp) + return + + if isinstance(inp, ourlang.StatementIf): + yield from statement_if(ctx, fun, inp) + return + + raise NotImplementedError(statement, fun, inp) + +def function(ctx: Context, inp: ourlang.Function) -> ConstraintGenerator: + assert not inp.imported + + if isinstance(inp, ourlang.StructConstructor): + return + + for stmt in inp.statements: + yield from statement(ctx, inp, stmt) + +def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> ConstraintGenerator: + yield from constant(ctx, inp.constant) + yield SameTypeConstraint(inp.type3, inp.constant.type3, + comment=f'The type of the value for module constant definition {inp.name} should match the type of that constant') + +def module(ctx: Context, inp: ourlang.Module) -> ConstraintGenerator: + for cdef in inp.constant_defs.values(): + yield from module_constant_def(ctx, cdef) + + for func in inp.functions.values(): + if func.imported: + continue + + yield from function(ctx, func) diff --git a/phasm/type3/entry.py b/phasm/type3/entry.py new file mode 100644 index 0000000..c054077 --- /dev/null +++ b/phasm/type3/entry.py @@ -0,0 +1,154 @@ +""" +Entry point to the type3 system +""" +from typing import Any, Dict, List, Set + +from .. import codestyle +from .. import ourlang + +from .constraints import ConstraintBase, Error, RequireTypeSubstitutes, SameTypeConstraint, SubstitutionMap +from .constraintsgenerator import phasm_type3_generate_constraints +from .types import AppliedType3, IntType3, PlaceholderForType, PrimitiveType3, StructType3, Type3, Type3OrPlaceholder + +MAX_RESTACK_COUNT = 100 + +class Type3Exception(BaseException): + """ + Thrown when the Type3 system detects constraints that do not hold + """ + +def phasm_type3(inp: ourlang.Module, verbose: bool = False) -> None: + constraint_list = phasm_type3_generate_constraints(inp) + assert constraint_list + + placeholder_substitutes: Dict[PlaceholderForType, Type3] = {} + placeholder_id_map: Dict[int, str] = {} + + error_list: List[Error] = [] + for _ in range(MAX_RESTACK_COUNT): + if verbose: + print() + print_constraint_list(placeholder_id_map, constraint_list, placeholder_substitutes) + + old_constraint_ids = {id(x) for x in constraint_list} + old_placeholder_substitutes_len = len(placeholder_substitutes) + + new_constraint_list = [] + for constraint in constraint_list: + check_result = constraint.check() + if check_result is None: + if verbose: + print_constraint(placeholder_id_map, constraint) + print('-> Constraint checks out') + continue + + if isinstance(check_result, dict): + placeholder_substitutes.update(check_result) + + if verbose: + print_constraint(placeholder_id_map, constraint) + print('-> Constraint checks out, and gave us new information') + continue + + if isinstance(check_result, Error): + error_list.append(check_result) + if verbose: + print_constraint(placeholder_id_map, constraint) + print('-> Got an error') + continue + + if isinstance(check_result, RequireTypeSubstitutes): + new_constraint_list.append(constraint) + + if verbose: + print_constraint(placeholder_id_map, constraint) + print('-> Back on the todo list') + continue + + if isinstance(check_result, list): + new_constraint_list.extend(check_result) + + if verbose: + print_constraint(placeholder_id_map, constraint) + print(f'-> Resulted in {len(check_result)} new constraints') + continue + + raise NotImplementedError(constraint, check_result) + + if not new_constraint_list: + constraint_list = new_constraint_list + break + + # Infinite loop detection + new_constraint_ids = {id(x) for x in new_constraint_list} + new_placeholder_substitutes_len = len(placeholder_substitutes) + + if old_constraint_ids == new_constraint_ids and old_placeholder_substitutes_len == new_placeholder_substitutes_len: + if error_list: + raise Type3Exception(error_list) + + raise Exception('Cannot type this program - not enough information') + + constraint_list = new_constraint_list + + if constraint_list: + raise Exception(f'Cannot type this program - tried {MAX_RESTACK_COUNT} iterations') + + if error_list: + raise Type3Exception(error_list) + + # FIXME: This doesn't work with e.g. `:: [a] -> a`, as the placeholder is inside a type + for plh, typ in placeholder_substitutes.items(): + for expr in plh.update_on_substitution: + assert expr.type3 is plh + + expr.type3 = typ + +def print_constraint(placeholder_id_map: Dict[int, str], constraint: ConstraintBase) -> None: + txt, fmt = constraint.human_readable() + act_fmt: Dict[str, str] = {} + for fmt_key, fmt_val in fmt.items(): + if isinstance(fmt_val, ourlang.Expression): + fmt_val = codestyle.expression(fmt_val) + + if isinstance(fmt_val, Type3) or isinstance(fmt_val, PlaceholderForType): + fmt_val = get_printable_type_name(fmt_val, placeholder_id_map) + + if not isinstance(fmt_val, str): + fmt_val = repr(fmt_val) + + act_fmt[fmt_key] = fmt_val + + if constraint.comment is not None: + print('- ' + txt.format(**act_fmt).ljust(40) + '; ' + constraint.comment) + else: + print('- ' + txt.format(**act_fmt)) + +def get_printable_type_name(inp: Type3OrPlaceholder, placeholder_id_map: Dict[int, str]) -> str: + if isinstance(inp, (PrimitiveType3, StructType3, IntType3, )): + return inp.name + + if isinstance(inp, PlaceholderForType): + placeholder_id = id(inp) + if placeholder_id not in placeholder_id_map: + placeholder_id_map[placeholder_id] = 'T' + str(len(placeholder_id_map) + 1) + return placeholder_id_map[placeholder_id] + + if isinstance(inp, AppliedType3): + return ( + get_printable_type_name(inp.base, placeholder_id_map) + + ' (' + + ') ('.join(get_printable_type_name(x, placeholder_id_map) for x in inp.args) + + ')' + ) + + raise NotImplementedError(inp) + +def print_constraint_list(placeholder_id_map: Dict[int, str], constraint_list: List[ConstraintBase], placeholder_substitutes: SubstitutionMap) -> None: + print('=== v type3 constraint_list v === ') + for psk, psv in placeholder_substitutes.items(): + print_constraint(placeholder_id_map, SameTypeConstraint(psk, psv, comment='Deduced type')) + + for constraint in constraint_list: + print_constraint(placeholder_id_map, constraint) + print('=== ^ type3 constraint_list ^ === ') diff --git a/phasm/type3/types.py b/phasm/type3/types.py new file mode 100644 index 0000000..bd10abf --- /dev/null +++ b/phasm/type3/types.py @@ -0,0 +1,322 @@ +""" +Contains the final types for use in Phasm + +These are actual, instantiated types; not the abstract types that the +constraint generator works with. +""" +from typing import Any, Dict, Iterable, List, Optional, Protocol, Union + +TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method' + +class ExpressionProtocol(Protocol): + """ + A protocol for classes that should be updated on substitution + """ + + type3: 'Type3OrPlaceholder' + """ + The type to update + """ + +class Type3: + """ + Base class for the type3 types + """ + __slots__ = ('name', ) + + name: str + """ + The name of the string, as parsed and outputted by codestyle. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __repr__(self) -> str: + return f'Type3("{self.name}")' + + def __str__(self) -> str: + return self.name + + def __format__(self, format_spec: str) -> str: + if format_spec != 's': + raise TypeError(f'unsupported format string passed to Type3.__format__: {format_spec}') + + return str(self) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, PlaceholderForType): + return False + + if not isinstance(other, Type3): + raise NotImplementedError + + return self is other + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + raise NotImplementedError + + def __bool__(self) -> bool: + raise NotImplementedError + +class PrimitiveType3(Type3): + """ + Intermediate class to tell primitive types from others + """ + + __slots__ = () + +class IntType3(Type3): + """ + Sometimes you can have an int as type, e.g. when using static arrays + """ + + __slots__ = ('value', ) + + value: int + + def __init__(self, value: int) -> None: + super().__init__(str(value)) + + assert 0 <= value + self.value = value + + def __eq__(self, other: Any) -> bool: + if isinstance(other, IntType3): + return self.value == other.value + + if isinstance(other, Type3): + return False + + raise NotImplementedError + +class PlaceholderForType: + """ + A placeholder type, for when we don't know the final type yet + """ + __slots__ = ('update_on_substitution', 'resolve_as', ) + + update_on_substitution: List[ExpressionProtocol] + resolve_as: Optional[Type3] + + def __init__(self, update_on_substitution: Iterable[ExpressionProtocol]) -> None: + self.update_on_substitution = [*update_on_substitution] + self.resolve_as = None + + def __repr__(self) -> str: + uos = ', '.join(repr(x) for x in self.update_on_substitution) + + return f'PlaceholderForType({id(self)}, [{uos}])' + + def __str__(self) -> str: + return f'PhFT_{id(self)}' + + def __format__(self, format_spec: str) -> str: + if format_spec != 's': + raise TypeError('unsupported format string passed to Type3.__format__') + + return str(self) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Type3): + return False + + if not isinstance(other, PlaceholderForType): + raise NotImplementedError + + return self is other + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return 0 # Valid but performs badly + + def __bool__(self) -> bool: + raise NotImplementedError + +Type3OrPlaceholder = Union[Type3, PlaceholderForType] + +class AppliedType3(Type3): + """ + A Type3 that has been applied to another type + """ + __slots__ = ('base', 'args', ) + + base: PrimitiveType3 + """ + The base type + """ + + args: List[Type3OrPlaceholder] + """ + The applied types (or placeholders there for) + """ + + def __init__(self, base: PrimitiveType3, args: Iterable[Type3OrPlaceholder]) -> None: + args = [*args] + assert args, 'Must at least one argument' + + super().__init__( + base.name + + ' (' + + ') ('.join(str(x) for x in args) # FIXME: Do we need to redo the name on substitution? + + ')' + ) + + self.base = base + self.args = args + + @property + def has_placeholders(self) -> bool: + return any( + isinstance(x, PlaceholderForType) + for x in self.args + ) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Type3): + raise NotImplementedError + + if not isinstance(other, AppliedType3): + return False + + return ( + self.base == other.base + and len(self.args) == len(other.args) + and all( + s == x + for s, x in zip(self.args, other.args) + ) + ) + + def __repr__(self) -> str: + return f'AppliedType3({repr(self.base)}, {repr(self.args)})' + +class StructType3(Type3): + """ + A Type3 struct with named members + """ + __slots__ = ('name', 'members', ) + + name: str + """ + The structs fully qualified name + """ + + members: Dict[str, Type3] + """ + The struct's field definitions + """ + + def __init__(self, name: str, members: Dict[str, Type3]) -> None: + super().__init__(name) + + self.name = name + self.members = dict(members) + + def __repr__(self) -> str: + return f'StructType3(repr({self.name}), repr({self.members}))' + +none = PrimitiveType3('none') +""" +The none type, for when functions simply don't return anything. e.g., IO(). +""" + +bool_ = PrimitiveType3('bool') +""" +The bool type, either True or False +""" + +u8 = PrimitiveType3('u8') +""" +The unsigned 8-bit integer type. + +Operations on variables employ modular arithmetic, with modulus 2^8. +""" + +u32 = PrimitiveType3('u32') +""" +The unsigned 32-bit integer type. + +Operations on variables employ modular arithmetic, with modulus 2^32. +""" + +u64 = PrimitiveType3('u64') +""" +The unsigned 64-bit integer type. + +Operations on variables employ modular arithmetic, with modulus 2^64. +""" + +i8 = PrimitiveType3('i8') +""" +The signed 8-bit integer type. + +Operations on variables employ modular arithmetic, with modulus 2^8, but +with the middel point being 0. +""" + +i32 = PrimitiveType3('i32') +""" +The unsigned 32-bit integer type. + +Operations on variables employ modular arithmetic, with modulus 2^32, but +with the middel point being 0. +""" + +i64 = PrimitiveType3('i64') +""" +The unsigned 64-bit integer type. + +Operations on variables employ modular arithmetic, with modulus 2^64, but +with the middel point being 0. +""" + +f32 = PrimitiveType3('f32') +""" +A 32-bits IEEE 754 float, of 32 bits width. +""" + +f64 = PrimitiveType3('f64') +""" +A 32-bits IEEE 754 float, of 64 bits width. +""" + +bytes = PrimitiveType3('bytes') +""" +This is a runtime-determined length piece of memory that can be indexed at runtime. +""" + +static_array = PrimitiveType3('static_array') +""" +This is a fixed length piece of memory that can be indexed at runtime. + +It should be applied with one argument. It has a runtime-dynamic length +of the same type repeated. +""" + +tuple = PrimitiveType3('tuple') # pylint: disable=W0622 +""" +This is a fixed length piece of memory. + +It should be applied with zero or more arguments. It has a compile time +determined length, and each argument can be different. +""" + +LOOKUP_TABLE: Dict[str, Type3] = { + 'none': none, + 'bool': bool_, + 'u8': u8, + 'u32': u32, + 'u64': u64, + 'i8': i8, + 'i32': i32, + 'i64': i64, + 'f32': f32, + 'f64': f64, + 'bytes': bytes, +} diff --git a/phasm/typing.py b/phasm/typing.py deleted file mode 100644 index e56f7a9..0000000 --- a/phasm/typing.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -The phasm type system -""" -from typing import Optional, List - -class TypeBase: - """ - TypeBase base class - """ - __slots__ = () - - def alloc_size(self) -> int: - """ - When allocating this type in memory, how many bytes do we need to reserve? - """ - raise NotImplementedError(self, 'alloc_size') - -class TypeNone(TypeBase): - """ - The None (or Void) type - """ - __slots__ = () - -class TypeBool(TypeBase): - """ - The boolean type - """ - __slots__ = () - -class TypeUInt8(TypeBase): - """ - The Integer type, unsigned and 8 bits wide - - Note that under the hood we need to use i32 to represent - these values in expressions. So we need to add some operations - to make sure the math checks out. - - So while this does save bytes in memory, it may not actually - speed up or improve your code. - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 # Int32 under the hood - -class TypeUInt32(TypeBase): - """ - The Integer type, unsigned and 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeUInt64(TypeBase): - """ - The Integer type, unsigned and 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeInt32(TypeBase): - """ - The Integer type, signed and 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeInt64(TypeBase): - """ - The Integer type, signed and 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeFloat32(TypeBase): - """ - The Float type, 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeFloat64(TypeBase): - """ - The Float type, 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeBytes(TypeBase): - """ - The bytes type - """ - __slots__ = () - -class TypeTupleMember: - """ - Represents a tuple member - """ - def __init__(self, idx: int, type_: TypeBase, offset: int) -> None: - self.idx = idx - self.type = type_ - self.offset = offset - -class TypeTuple(TypeBase): - """ - The tuple type - """ - __slots__ = ('members', ) - - members: List[TypeTupleMember] - - def __init__(self) -> None: - self.members = [] - - def render_internal_name(self) -> str: - """ - Generates an internal name for this tuple - """ - mems = '@'.join('?' for x in self.members) # FIXME: Should not be a questionmark - assert ' ' not in mems, 'Not implement yet: subtuples' - return f'tuple@{mems}' - - def alloc_size(self) -> int: - return sum( - x.type.alloc_size() - for x in self.members - ) - -class TypeStaticArrayMember: - """ - Represents a static array member - """ - def __init__(self, idx: int, offset: int) -> None: - self.idx = idx - self.offset = offset - -class TypeStaticArray(TypeBase): - """ - The static array type - """ - __slots__ = ('member_type', 'members', ) - - member_type: TypeBase - members: List[TypeStaticArrayMember] - - def __init__(self, member_type: TypeBase) -> None: - self.member_type = member_type - self.members = [] - - def alloc_size(self) -> int: - return self.member_type.alloc_size() * len(self.members) - -class TypeStructMember: - """ - Represents a struct member - """ - def __init__(self, name: str, type_: TypeBase, offset: int) -> None: - self.name = name - self.type = type_ - self.offset = offset - -class TypeStruct(TypeBase): - """ - A struct has named properties - """ - __slots__ = ('name', 'lineno', 'members', ) - - name: str - lineno: int - members: List[TypeStructMember] - - def __init__(self, name: str, lineno: int) -> None: - self.name = name - self.lineno = lineno - self.members = [] - - def get_member(self, name: str) -> Optional[TypeStructMember]: - """ - Returns a member by name - """ - for mem in self.members: - if mem.name == name: - return mem - - return None - - def alloc_size(self) -> int: - return sum( - x.type.alloc_size() - for x in self.members - ) diff --git a/phasm/wasmeasy.py b/phasm/wasmeasy.py deleted file mode 100644 index d0cf358..0000000 --- a/phasm/wasmeasy.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Helper functions to quickly generate WASM code -""" -from typing import Any, Dict, List, Optional, Type - -import functools - -from . import wasm - -#pylint: disable=C0103,C0115,C0116,R0201,R0902 - -class Prefix_inn_fnn: - def __init__(self, prefix: str) -> None: - self.prefix = prefix - - # 6.5.5. Memory Instructions - self.load = functools.partial(wasm.Statement, f'{self.prefix}.load') - self.store = functools.partial(wasm.Statement, f'{self.prefix}.store') - - # 6.5.6. Numeric Instructions - self.clz = functools.partial(wasm.Statement, f'{self.prefix}.clz') - self.ctz = functools.partial(wasm.Statement, f'{self.prefix}.ctz') - self.popcnt = functools.partial(wasm.Statement, f'{self.prefix}.popcnt') - self.add = functools.partial(wasm.Statement, f'{self.prefix}.add') - self.sub = functools.partial(wasm.Statement, f'{self.prefix}.sub') - self.mul = functools.partial(wasm.Statement, f'{self.prefix}.mul') - self.div_s = functools.partial(wasm.Statement, f'{self.prefix}.div_s') - self.div_u = functools.partial(wasm.Statement, f'{self.prefix}.div_u') - self.rem_s = functools.partial(wasm.Statement, f'{self.prefix}.rem_s') - self.rem_u = functools.partial(wasm.Statement, f'{self.prefix}.rem_u') - self.and_ = functools.partial(wasm.Statement, f'{self.prefix}.and') - self.or_ = functools.partial(wasm.Statement, f'{self.prefix}.or') - self.xor = functools.partial(wasm.Statement, f'{self.prefix}.xor') - self.shl = functools.partial(wasm.Statement, f'{self.prefix}.shl') - self.shr_s = functools.partial(wasm.Statement, f'{self.prefix}.shr_s') - self.shr_u = functools.partial(wasm.Statement, f'{self.prefix}.shr_u') - self.rotl = functools.partial(wasm.Statement, f'{self.prefix}.rotl') - self.rotr = functools.partial(wasm.Statement, f'{self.prefix}.rotr') - - self.eqz = functools.partial(wasm.Statement, f'{self.prefix}.eqz') - self.eq = functools.partial(wasm.Statement, f'{self.prefix}.eq') - self.ne = functools.partial(wasm.Statement, f'{self.prefix}.ne') - self.lt_s = functools.partial(wasm.Statement, f'{self.prefix}.lt_s') - self.lt_u = functools.partial(wasm.Statement, f'{self.prefix}.lt_u') - self.gt_s = functools.partial(wasm.Statement, f'{self.prefix}.gt_s') - self.gt_u = functools.partial(wasm.Statement, f'{self.prefix}.gt_u') - self.le_s = functools.partial(wasm.Statement, f'{self.prefix}.le_s') - self.le_u = functools.partial(wasm.Statement, f'{self.prefix}.le_u') - self.ge_s = functools.partial(wasm.Statement, f'{self.prefix}.ge_s') - self.ge_u = functools.partial(wasm.Statement, f'{self.prefix}.ge_u') - - def const(self, value: int, comment: Optional[str] = None) -> wasm.Statement: - return wasm.Statement(f'{self.prefix}.const', f'0x{value:08x}', comment=comment) - -i32 = Prefix_inn_fnn('i32') -i64 = Prefix_inn_fnn('i64') - -class Block: - def __init__(self, start: str) -> None: - self.start = start - - def __call__(self, *statements: wasm.Statement) -> List[wasm.Statement]: - return [ - wasm.Statement('if'), - *statements, - wasm.Statement('end'), - ] - -if_ = Block('if') diff --git a/pylintrc b/pylintrc index 0591be3..f872f8d 100644 --- a/pylintrc +++ b/pylintrc @@ -1,5 +1,5 @@ [MASTER] -disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 +disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223 max-line-length=180 @@ -7,4 +7,4 @@ max-line-length=180 good-names=g [tests] -disable=C0116, +disable=C0116,R0201 diff --git a/requirements.txt b/requirements.txt index d29e53c..64dce30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -mypy==0.812 +mypy==0.991 pygments==2.12.0 -pylint==2.7.4 -pytest==6.2.2 +pylint==2.15.9 +pytest==7.2.0 pytest-integration==0.2.2 pywasm==1.0.7 pywasm3==0.5.0 wasmer==1.1.0 wasmer_compiler_cranelift==1.1.0 -wasmtime==0.36.0 +wasmtime==3.0.0 diff --git a/tests/integration/constants.py b/tests/integration/constants.py new file mode 100644 index 0000000..c2f8860 --- /dev/null +++ b/tests/integration/constants.py @@ -0,0 +1,16 @@ +""" +Constants for use in the tests +""" + +ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64'] +COMPLETE_INT_TYPES = ['u32', 'u64', 'i32', 'i64'] + +ALL_FLOAT_TYPES = ['f32', 'f64'] +COMPLETE_FLOAT_TYPES = ALL_FLOAT_TYPES + +TYPE_MAP = { + **{x: int for x in ALL_INT_TYPES}, + **{x: float for x in ALL_FLOAT_TYPES}, +} + +COMPLETE_NUMERIC_TYPES = COMPLETE_INT_TYPES + COMPLETE_FLOAT_TYPES diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index ca4c8a7..3cd682c 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -24,7 +24,7 @@ class Suite: def __init__(self, code_py): self.code_py = code_py - def run_code(self, *args, runtime='pywasm3', imports=None): + def run_code(self, *args, runtime='pywasm3', func_name='testEntry', imports=None): """ Compiles the given python code into wasm and then runs it @@ -74,7 +74,7 @@ class Suite: runner.interpreter_dump_memory(sys.stderr) result = SuiteResult() - result.returned_value = runner.call('testEntry', *wasm_args) + result.returned_value = runner.call(func_name, *wasm_args) write_header(sys.stderr, 'Memory (post run)') runner.interpreter_dump_memory(sys.stderr) diff --git a/tests/integration/runners.py b/tests/integration/runners.py index fd3a53e..91a29fd 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -13,6 +13,7 @@ import wasmtime from phasm.compiler import phasm_compile from phasm.parser import phasm_parse +from phasm.type3.entry import phasm_type3 from phasm import ourlang from phasm import wasm @@ -40,6 +41,7 @@ class RunnerBase: Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) + phasm_type3(self.phasm_ast, verbose=True) def compile_ast(self) -> None: """ @@ -160,7 +162,7 @@ class RunnerPywasm3(RunnerBase): memory = self.rtime.get_memory(0) for idx, byt in enumerate(data): - memory[offset + idx] = byt # type: ignore + memory[offset + idx] = byt def interpreter_read_memory(self, offset: int, length: int) -> bytes: memory = self.rtime.get_memory(0) diff --git a/tests/integration/test_constants.py b/tests/integration/test_constants.py deleted file mode 100644 index 19f0203..0000000 --- a/tests/integration/test_constants.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest - -from .helpers import Suite - -@pytest.mark.integration_test -def test_i32(): - code_py = """ -CONSTANT: i32 = 13 - -@exported -def testEntry() -> i32: - return CONSTANT * 5 -""" - - result = Suite(code_py).run_code() - - assert 65 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) -def test_tuple_1(type_): - code_py = f""" -CONSTANT: ({type_}, ) = (65, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(vector: ({type_}, )) -> {type_}: - return vector[0] -""" - - result = Suite(code_py).run_code() - - assert 65 == result.returned_value - -@pytest.mark.integration_test -def test_tuple_6(): - code_py = """ -CONSTANT: (u8, u8, u32, u32, u64, u64, ) = (11, 22, 3333, 4444, 555555, 666666, ) - -@exported -def testEntry() -> u32: - return helper(CONSTANT) - -def helper(vector: (u8, u8, u32, u32, u64, u64, )) -> u32: - return vector[2] -""" - - result = Suite(code_py).run_code() - - assert 3333 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) -def test_static_array_1(type_): - code_py = f""" -CONSTANT: {type_}[1] = (65, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(vector: {type_}[1]) -> {type_}: - return vector[0] -""" - - result = Suite(code_py).run_code() - - assert 65 == result.returned_value - -@pytest.mark.integration_test -def test_static_array_6(): - code_py = """ -CONSTANT: u32[6] = (11, 22, 3333, 4444, 555555, 666666, ) - -@exported -def testEntry() -> u32: - return helper(CONSTANT) - -def helper(vector: u32[6]) -> u32: - return vector[2] -""" - - result = Suite(code_py).run_code() - - assert 3333 == result.returned_value diff --git a/tests/integration/test_examples/__init__.py b/tests/integration/test_examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_examples/test_buffer.py b/tests/integration/test_examples/test_buffer.py new file mode 100644 index 0000000..b987c69 --- /dev/null +++ b/tests/integration/test_examples/test_buffer.py @@ -0,0 +1,19 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.slow_integration_test +def test_index(): + with open('examples/buffer.py', 'r', encoding='ASCII') as fil: + code_py = "\n" + fil.read() + + result = Suite(code_py).run_code(b'Hello, world!', 5, func_name='index', runtime='wasmtime') + assert 44 == result.returned_value + +@pytest.mark.slow_integration_test +def test_length(): + with open('examples/buffer.py', 'r', encoding='ASCII') as fil: + code_py = "\n" + fil.read() + + result = Suite(code_py).run_code(b'Hello, world!', func_name='length') + assert 13 == result.returned_value diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples/test_crc32.py similarity index 93% rename from tests/integration/test_examples.py rename to tests/integration/test_examples/test_crc32.py index b3b278d..d9f74bf 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples/test_crc32.py @@ -3,9 +3,9 @@ import struct import pytest -from .helpers import Suite +from ..helpers import Suite -@pytest.mark.integration_test +@pytest.mark.slow_integration_test def test_crc32(): # FIXME: Stub # crc = 0xFFFFFFFF diff --git a/tests/integration/test_examples/test_fib.py b/tests/integration/test_examples/test_fib.py new file mode 100644 index 0000000..9474a20 --- /dev/null +++ b/tests/integration/test_examples/test_fib.py @@ -0,0 +1,12 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.slow_integration_test +def test_fib(): + with open('./examples/fib.py', 'r', encoding='UTF-8') as fil: + code_py = "\n" + fil.read() + + result = Suite(code_py).run_code() + + assert 102334155 == result.returned_value diff --git a/tests/integration/test_fib.py b/tests/integration/test_fib.py deleted file mode 100644 index 20e7e63..0000000 --- a/tests/integration/test_fib.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from .helpers import Suite - -@pytest.mark.slow_integration_test -def test_fib(): - code_py = """ -def helper(n: i32, a: i32, b: i32) -> i32: - if n < 1: - return a + b - - return helper(n - 1, a + b, a) - -def fib(n: i32) -> i32: - if n == 0: - return 0 - - if n == 1: - return 1 - - return helper(n - 1, 0, 1) - -@exported -def testEntry() -> i32: - return fib(40) -""" - - result = Suite(code_py).run_code() - - assert 102334155 == result.returned_value diff --git a/tests/integration/test_helper.py b/tests/integration/test_helper.py deleted file mode 100644 index cb44021..0000000 --- a/tests/integration/test_helper.py +++ /dev/null @@ -1,70 +0,0 @@ -import io - -import pytest - -from pywasm import binary -from pywasm import Runtime - -from wasmer import wat2wasm - -def run(code_wat): - code_wasm = wat2wasm(code_wat) - module = binary.Module.from_reader(io.BytesIO(code_wasm)) - - runtime = Runtime(module, {}, {}) - - out_put = runtime.exec('testEntry', []) - return (runtime, out_put) - -@pytest.mark.parametrize('size,offset,exp_out_put', [ - ('32', 0, 0x3020100), - ('32', 1, 0x4030201), - ('64', 0, 0x706050403020100), - ('64', 2, 0x908070605040302), -]) -def test_i32_64_load(size, offset, exp_out_put): - code_wat = f""" - (module - (memory 1) - (data (memory 0) (i32.const 0) "\\00\\01\\02\\03\\04\\05\\06\\07\\08\\09\\10") - - (func (export "testEntry") (result i{size}) - i32.const {offset} - i{size}.load - return )) -""" - - (_, out_put) = run(code_wat) - assert exp_out_put == out_put - -def test_load_then_store(): - code_wat = """ - (module - (memory 1) - (data (memory 0) (i32.const 0) "\\04\\00\\00\\00") - - (func (export "testEntry") (result i32) (local $my_memory_value i32) - ;; Load i32 from address 0 - i32.const 0 - i32.load - - ;; Add 8 to the loaded value - i32.const 8 - i32.add - - local.set $my_memory_value - - ;; Store back to the memory - i32.const 0 - local.get $my_memory_value - i32.store - - ;; Return something - i32.const 9 - return )) -""" - (runtime, out_put) = run(code_wat) - - assert 9 == out_put - - assert (b'\x0c'+ b'\00' * 23) == runtime.store.mems[0].data[:24] diff --git a/tests/integration/test_lang/__init__.py b/tests/integration/test_lang/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_builtins.py b/tests/integration/test_lang/test_builtins.py similarity index 95% rename from tests/integration/test_builtins.py rename to tests/integration/test_lang/test_builtins.py index 4b84197..2b06afd 100644 --- a/tests/integration/test_builtins.py +++ b/tests/integration/test_lang/test_builtins.py @@ -2,8 +2,8 @@ import sys import pytest -from .helpers import Suite, write_header -from .runners import RunnerPywasm +from ..helpers import Suite, write_header +from ..runners import RunnerPywasm def setup_interpreter(phash_code: str) -> RunnerPywasm: runner = RunnerPywasm(phash_code) diff --git a/tests/integration/test_lang/test_bytes.py b/tests/integration/test_lang/test_bytes.py new file mode 100644 index 0000000..af3f78d --- /dev/null +++ b/tests/integration/test_lang/test_bytes.py @@ -0,0 +1,84 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..helpers import Suite + +@pytest.mark.integration_test +def test_bytes_address(): + code_py = """ +@exported +def testEntry(f: bytes) -> bytes: + return f +""" + + result = Suite(code_py).run_code(b'This is a test') + + # THIS DEPENDS ON THE ALLOCATOR + # A different allocator will return a different value + assert 20 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_length(): + code_py = """ +@exported +def testEntry(f: bytes) -> u32: + return len(f) +""" + + result = Suite(code_py).run_code(b'This yet is another test') + + assert 24 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[8] +""" + + result = Suite(code_py).run_code(b'This is another test') + + assert 0x61 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index_out_of_bounds(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[50] +""" + + result = Suite(code_py).run_code(b'Short', b'Long' * 100) + + assert 0 == result.returned_value + +@pytest.mark.integration_test +def test_function_call_element_ok(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return helper(f[0]) + +def helper(x: u8) -> u8: + return x +""" + + result = Suite(code_py).run_code(b'Short') + + assert 83 == result.returned_value + +@pytest.mark.integration_test +def test_function_call_element_type_mismatch(): + code_py = """ +@exported +def testEntry(f: bytes) -> u64: + return helper(f[0]) + +def helper(x: u64) -> u64: + return x +""" + + with pytest.raises(Type3Exception, match=r'u64 must be u8 instead'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_if.py b/tests/integration/test_lang/test_if.py new file mode 100644 index 0000000..5d77eb9 --- /dev/null +++ b/tests/integration/test_lang/test_if.py @@ -0,0 +1,71 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +@pytest.mark.parametrize('inp', [9, 10, 11, 12]) +def test_if_simple(inp): + code_py = """ +@exported +def testEntry(a: i32) -> i32: + if a > 10: + return 15 + + return 3 +""" + exp_result = 15 if inp > 10 else 3 + + suite = Suite(code_py) + + result = suite.run_code(inp) + assert exp_result == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.skip('Such a return is not how things should be') +def test_if_complex(): + code_py = """ +@exported +def testEntry(a: i32) -> i32: + if a > 10: + return 10 + elif a > 0: + return a + else: + return 0 + + return -1 # Required due to function type +""" + + suite = Suite(code_py) + + assert 10 == suite.run_code(20).returned_value + assert 10 == suite.run_code(10).returned_value + + assert 8 == suite.run_code(8).returned_value + + assert 0 == suite.run_code(0).returned_value + assert 0 == suite.run_code(-1).returned_value + +@pytest.mark.integration_test +def test_if_nested(): + code_py = """ +@exported +def testEntry(a: i32, b: i32) -> i32: + if a > 11: + if b > 11: + return 3 + + return 2 + + if b > 11: + return 1 + + return 0 +""" + + suite = Suite(code_py) + + assert 3 == suite.run_code(20, 20).returned_value + assert 2 == suite.run_code(20, 10).returned_value + assert 1 == suite.run_code(10, 20).returned_value + assert 0 == suite.run_code(10, 10).returned_value diff --git a/tests/integration/test_lang/test_interface.py b/tests/integration/test_lang/test_interface.py new file mode 100644 index 0000000..a8dfc32 --- /dev/null +++ b/tests/integration/test_lang/test_interface.py @@ -0,0 +1,79 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..helpers import Suite + +@pytest.mark.integration_test +def test_imported_ok(): + code_py = """ +@imported +def helper(mul: i32) -> i32: + pass + +@exported +def testEntry() -> i32: + return helper(2) +""" + + def helper(mul: int) -> int: + return 4238 * mul + + result = Suite(code_py).run_code( + runtime='wasmer', + imports={ + 'helper': helper, + } + ) + + assert 8476 == result.returned_value + +@pytest.mark.integration_test +def test_imported_side_effect_no_return(): + code_py = """ +@imported +def helper(mul: u8) -> None: + pass + +@exported +def testEntry() -> None: + return helper(3) +""" + prop = None + + def helper(mul: int) -> None: + nonlocal prop + prop = mul + + result = Suite(code_py).run_code( + runtime='wasmer', + imports={ + 'helper': helper, + } + ) + + assert None is result.returned_value + assert 3 == prop + +@pytest.mark.integration_test +def test_imported_type_mismatch(): + code_py = """ +@imported +def helper(mul: u8) -> u8: + pass + +@exported +def testEntry(x: u32) -> u8: + return helper(x) +""" + + def helper(mul: int) -> int: + return 4238 * mul + + with pytest.raises(Type3Exception, match=r'u32 must be u8 instead'): + Suite(code_py).run_code( + runtime='wasmer', + imports={ + 'helper': helper, + } + ) diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py new file mode 100644 index 0000000..41a8211 --- /dev/null +++ b/tests/integration/test_lang/test_primitives.py @@ -0,0 +1,486 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..helpers import Suite +from ..constants import ALL_INT_TYPES, ALL_FLOAT_TYPES, COMPLETE_INT_TYPES, TYPE_MAP + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_expr_constant_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 13 +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_expr_constant_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_expr_constant_literal_does_not_fit(): + code_py = """ +@exported +def testEntry() -> u8: + return 1000 +""" + + with pytest.raises(Type3Exception, match=r'Must fit in 1 byte\(s\)'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_module_constant_int(type_): + code_py = f""" +CONSTANT: {type_} = 13 + +@exported +def testEntry() -> {type_}: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_module_constant_float(type_): + code_py = f""" +CONSTANT: {type_} = 32.125 + +@exported +def testEntry() -> {type_}: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + +@pytest.mark.integration_test +def test_module_constant_type_failure(): + code_py = """ +CONSTANT: u8 = 1000 + +@exported +def testEntry() -> u32: + return 14 +""" + + with pytest.raises(Type3Exception, match=r'Must fit in 1 byte\(s\)'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation +def test_logical_left_shift(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 << 3 +""" + + result = Suite(code_py).run_code() + + assert 80 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u32', 'u64']) +def test_logical_right_shift_left_bit_zero(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 >> 3 +""" + + # Check with wasmtime, as other engines don't mind if the type + # doesn't match. They'll complain when: (>>) : u32 -> u64 -> u32 + result = Suite(code_py).run_code(runtime='wasmtime') + + assert 1 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_logical_right_shift_left_bit_one(): + code_py = """ +@exported +def testEntry() -> u32: + return 4294967295 >> 16 +""" + + result = Suite(code_py).run_code() + + assert 0xFFFF == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_or_uint(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 | 3 +""" + + result = Suite(code_py).run_code() + + assert 11 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_bitwise_or_inv_type(): + code_py = """ +@exported +def testEntry() -> f64: + return 10.0 | 3.0 +""" + + with pytest.raises(Type3Exception, match='f64 does not implement the BitWiseOperation type class'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_bitwise_or_type_mismatch(): + code_py = """ +CONSTANT1: u32 = 3 +CONSTANT2: u64 = 3 + +@exported +def testEntry() -> u64: + return CONSTANT1 | CONSTANT2 +""" + + with pytest.raises(Type3Exception, match='u64 must be u32 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_xor(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 ^ 3 +""" + + result = Suite(code_py).run_code() + + assert 9 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_and(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 & 3 +""" + + result = Suite(code_py).run_code() + + assert 2 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_addition_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 + 3 +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_addition_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.0 + 0.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_subtraction_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 - 3 +""" + + result = Suite(code_py).run_code() + + assert 7 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_subtraction_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 100.0 - 67.875 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.skip('TODO: Runtimes return a signed value, which is difficult to test') +@pytest.mark.parametrize('type_', ('u32', 'u64')) # FIXME: u8 +def test_subtraction_underflow(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 - 11 +""" + + result = Suite(code_py).run_code() + + assert 0 < result.returned_value + +# TODO: Multiplication + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_division_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 / 3 +""" + + result = Suite(code_py).run_code() + + assert 3 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_division_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10.0 / 8.0 +""" + + result = Suite(code_py).run_code() + + assert 1.25 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_division_zero_let_it_crash_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 / 0 +""" + + # WebAssembly dictates that integer division is a partial operator (e.g. unreachable for 0) + # https://www.w3.org/TR/wasm-core-1/#-hrefop-idiv-umathrmidiv_u_n-i_1-i_2 + with pytest.raises(Exception): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_division_zero_let_it_crash_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10.0 / 0.0 +""" + + # WebAssembly dictates that float division follows the IEEE rules + # https://www.w3.org/TR/wasm-core-1/#-hrefop-fdivmathrmfdiv_n-z_1-z_2 + result = Suite(code_py).run_code() + assert float('+inf') == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['f32', 'f64']) +def test_builtins_sqrt(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return sqrt(25.0) +""" + + result = Suite(code_py).run_code() + + assert 5 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', TYPE_MAP.keys()) +def test_function_argument(type_): + code_py = f""" +@exported +def testEntry(a: {type_}) -> {type_}: + return a +""" + + result = Suite(code_py).run_code(125) + + assert 125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.skip('TODO') +def test_explicit_positive_number(): + code_py = """ +@exported +def testEntry() -> i32: + return +523 +""" + + result = Suite(code_py).run_code() + + assert 523 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.skip('TODO') +def test_explicit_negative_number(): + code_py = """ +@exported +def testEntry() -> i32: + return -19 +""" + + result = Suite(code_py).run_code() + + assert -19 == result.returned_value + +@pytest.mark.integration_test +def test_call_no_args(): + code_py = """ +def helper() -> i32: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + result = Suite(code_py).run_code() + + assert 19 == result.returned_value + +@pytest.mark.integration_test +def test_call_pre_defined(): + code_py = """ +def helper(left: i32) -> i32: + return left + +@exported +def testEntry() -> i32: + return helper(13) +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +def test_call_post_defined(): + code_py = """ +@exported +def testEntry() -> i32: + return helper(10, 3) + +def helper(left: i32, right: i32) -> i32: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 7 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_call_with_expression_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10 + 20, 3 + 5) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 22 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_call_with_expression_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10.078125 + 90.046875, 63.0 + 5.0) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_call_invalid_return_type(): + code_py = """ +def helper() -> i64: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + with pytest.raises(Type3Exception, match=r'i64 must be i32 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_call_invalid_arg_type(): + code_py = """ +def helper(left: u8) -> u8: + return left + +@exported +def testEntry() -> u8: + return helper(500) +""" + + with pytest.raises(Type3Exception, match=r'Must fit in 1 byte\(s\)'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py new file mode 100644 index 0000000..4d32cd8 --- /dev/null +++ b/tests/integration/test_lang/test_static_array.py @@ -0,0 +1,195 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..constants import ( + ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_NUMERIC_TYPES, TYPE_MAP +) +from ..helpers import Suite + +@pytest.mark.integration_test +def test_module_constant_def(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +@exported +def testEntry() -> i32: + return 0 +""" + + result = Suite(code_py).run_code() + + assert 0 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_module_constant_3(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return CONSTANT[1] +""" + + result = Suite(code_py).run_code() + + assert 57 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_function_call_int(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_function_call_float(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24.0, 57.5, 80.75, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 162.25 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_function_call_element_ok(): + code_py = """ +CONSTANT: u64[3] = (250, 250000, 250000000, ) + +@exported +def testEntry() -> u64: + return helper(CONSTANT[0]) + +def helper(x: u64) -> u64: + return x +""" + + result = Suite(code_py).run_code() + + assert 250 == result.returned_value + +@pytest.mark.integration_test +def test_function_call_element_type_mismatch(): + code_py = """ +CONSTANT: u64[3] = (250, 250000, 250000000, ) + +@exported +def testEntry() -> u8: + return helper(CONSTANT[0]) + +def helper(x: u8) -> u8: + return x +""" + + with pytest.raises(Type3Exception, match=r'u8 must be u64 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_bitwidth(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 280, ) +""" + + with pytest.raises(Type3Exception, match=r'Must fit in 1 byte\(s\)'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_return_as_int(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +def testEntry() -> u32: + return CONSTANT +""" + + with pytest.raises(Type3Exception, match=r'static_array \(u8\) \(3\) must be u32 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_not_subscriptable(): + code_py = """ +CONSTANT: u8 = 24 + +@exported +def testEntry() -> u8: + return CONSTANT[0] +""" + + with pytest.raises(Type3Exception, match='u8 cannot be subscripted'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_index_out_of_range_constant(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +@exported +def testEntry() -> u8: + return CONSTANT[3] +""" + + with pytest.raises(Type3Exception, match='3 must be less or equal than 2'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_index_out_of_range_variable(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +@exported +def testEntry(x: u32) -> u8: + return CONSTANT[x] +""" + + with pytest.raises(RuntimeError): + Suite(code_py).run_code(3) + +@pytest.mark.integration_test +def test_static_array_constant_too_few_values(): + code_py = """ +CONSTANT: u8[4] = (24, 57, ) + +@exported +def testEntry() -> i32: + return 0 +""" + + with pytest.raises(Type3Exception, match='Member count mismatch'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_static_array_constant_too_many_values(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 1, 1, ) + +@exported +def testEntry() -> i32: + return 0 +""" + + with pytest.raises(Type3Exception, match='Member count mismatch'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_struct.py b/tests/integration/test_lang/test_struct.py new file mode 100644 index 0000000..bcef772 --- /dev/null +++ b/tests/integration/test_lang/test_struct.py @@ -0,0 +1,154 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..constants import ( + ALL_INT_TYPES, TYPE_MAP +) +from ..helpers import Suite + +@pytest.mark.integration_test +def test_module_constant_def(): + code_py = """ +class SomeStruct: + value0: u8 + value1: u32 + value2: u64 + +CONSTANT: SomeStruct = SomeStruct(250, 250000, 250000000) + +@exported +def testEntry() -> i32: + return 0 +""" + + result = Suite(code_py).run_code() + + assert 0 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_module_constant(type_): + code_py = f""" +class CheckedValue: + value: {type_} + +CONSTANT: CheckedValue = CheckedValue(24) + +@exported +def testEntry() -> {type_}: + return CONSTANT.value +""" + + result = Suite(code_py).run_code() + + assert 24 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_struct_0(type_): + code_py = f""" +class CheckedValue: + value: {type_} + +@exported +def testEntry() -> {type_}: + return helper(CheckedValue(23)) + +def helper(cv: CheckedValue) -> {type_}: + return cv.value +""" + + result = Suite(code_py).run_code() + + assert 23 == result.returned_value + +@pytest.mark.integration_test +def test_struct_1(): + code_py = """ +class Rectangle: + height: i32 + width: i32 + border: i32 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2)) + +def helper(shape: Rectangle) -> i32: + return shape.height + shape.width + shape.border +""" + + result = Suite(code_py).run_code() + + assert 252 == result.returned_value + +@pytest.mark.integration_test +def test_struct_2(): + code_py = """ +class Rectangle: + height: i32 + width: i32 + border: i32 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) + +def helper(shape1: Rectangle, shape2: Rectangle) -> i32: + return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border +""" + + result = Suite(code_py).run_code() + + assert 545 == result.returned_value + +@pytest.mark.integration_test +def test_returned_struct(): + code_py = """ +class CheckedValue: + value: u8 + +CONSTANT: CheckedValue = CheckedValue(199) + +def helper() -> CheckedValue: + return CONSTANT + +def helper2(x: CheckedValue) -> u8: + return x.value + +@exported +def testEntry() -> u8: + return helper2(helper()) +""" + + result = Suite(code_py).run_code() + + assert 199 == result.returned_value + +@pytest.mark.integration_test +def test_type_mismatch_arg_module_constant(): + code_py = """ +class Struct: + param: f32 + +STRUCT: Struct = Struct(1) +""" + + with pytest.raises(Type3Exception, match='Must be real'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) +def test_type_mismatch_struct_member(type_): + code_py = f""" +class Struct: + param: {type_} + +def testEntry(arg: Struct) -> (i32, i32, ): + return arg.param +""" + + with pytest.raises(Type3Exception, match=type_ + r' must be tuple \(i32\) \(i32\) instead'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_tuple.py b/tests/integration/test_lang/test_tuple.py new file mode 100644 index 0000000..a238932 --- /dev/null +++ b/tests/integration/test_lang/test_tuple.py @@ -0,0 +1,188 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..constants import ALL_FLOAT_TYPES, COMPLETE_INT_TYPES, TYPE_MAP +from ..helpers import Suite + +@pytest.mark.integration_test +def test_module_constant_def(): + code_py = """ +CONSTANT: (u8, u32, u64, ) = (250, 250000, 250000000, ) + +@exported +def testEntry() -> i32: + return 0 +""" + + result = Suite(code_py).run_code() + + assert 0 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) +def test_module_constant_1(type_): + code_py = f""" +CONSTANT: ({type_}, ) = (65, ) + +@exported +def testEntry() -> {type_}: + return CONSTANT[0] +""" + + result = Suite(code_py).run_code() + + assert 65 == result.returned_value + +@pytest.mark.integration_test +def test_module_constant_6(): + code_py = """ +CONSTANT: (u8, u8, u32, u32, u64, u64, ) = (11, 22, 3333, 4444, 555555, 666666, ) + +@exported +def testEntry() -> u32: + return CONSTANT[2] +""" + + result = Suite(code_py).run_code() + + assert 3333 == result.returned_value + +@pytest.mark.integration_test +def test_function_call_element_ok(): + code_py = """ +CONSTANT: (u8, u32, u64, ) = (250, 250000, 250000000, ) + +@exported +def testEntry() -> u64: + return helper(CONSTANT[2]) + +def helper(x: u64) -> u64: + return x +""" + + result = Suite(code_py).run_code() + + assert 250000000 == result.returned_value + +@pytest.mark.integration_test +def test_function_call_element_type_mismatch(): + code_py = """ +CONSTANT: (u8, u32, u64, ) = (250, 250000, 250000000, ) + +@exported +def testEntry() -> u8: + return helper(CONSTANT[2]) + +def helper(x: u8) -> u8: + return x +""" + + with pytest.raises(Type3Exception, match=r'u8 must be u64 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_tuple_simple_constructor_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper((24, 57, 80, )) + +def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: + return vector[0] + vector[1] + vector[2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_tuple_simple_constructor_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper((1.0, 2.0, 3.0, )) + +def helper(v: ({type_}, {type_}, {type_}, )) -> {type_}: + return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) +""" + + result = Suite(code_py).run_code() + + assert 3.74 < result.returned_value < 3.75 + +@pytest.mark.integration_test +@pytest.mark.skip('SIMD support is but a dream') +def test_tuple_i32x4(): + code_py = """ +@exported +def testEntry() -> i32x4: + return (51, 153, 204, 0, ) +""" + + result = Suite(code_py).run_code() + + assert (1, 2, 3, 0) == result.returned_value + +@pytest.mark.integration_test +def test_assign_to_tuple_with_tuple(): + code_py = """ +CONSTANT: (u32, ) = 0 +""" + + with pytest.raises(Type3Exception, match='Must be tuple'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_tuple_constant_too_few_values(): + code_py = """ +CONSTANT: (u32, u8, u8, ) = (24, 57, ) +""" + + with pytest.raises(Type3Exception, match='Tuple element count mismatch'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_tuple_constant_too_many_values(): + code_py = """ +CONSTANT: (u32, u8, u8, ) = (24, 57, 1, 1, ) +""" + + with pytest.raises(Type3Exception, match='Tuple element count mismatch'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_tuple_constant_type_mismatch(): + code_py = """ +CONSTANT: (u32, u8, u8, ) = (24, 4000, 1, ) +""" + + with pytest.raises(Type3Exception, match=r'Must fit in 1 byte\(s\)'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_tuple_must_use_literal_for_indexing(): + code_py = """ +CONSTANT: u32 = 0 + +@exported +def testEntry(x: (u8, u32, u64)) -> u64: + return x[CONSTANT] +""" + + with pytest.raises(Type3Exception, match='Must index with literal'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_tuple_must_use_integer_for_indexing(): + code_py = """ +@exported +def testEntry(x: (u8, u32, u64)) -> u64: + return x[0.0] +""" + + with pytest.raises(Type3Exception, match='Must index with integer literal'): + Suite(code_py).run_code() diff --git a/tests/integration/test_runtime_checks.py b/tests/integration/test_runtime_checks.py deleted file mode 100644 index 97d6542..0000000 --- a/tests/integration/test_runtime_checks.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -from .helpers import Suite - -@pytest.mark.integration_test -def test_bytes_index_out_of_bounds(): - code_py = """ -@exported -def testEntry(f: bytes) -> u8: - return f[50] -""" - - result = Suite(code_py).run_code(b'Short', b'Long' * 100) - - assert 0 == result.returned_value - -@pytest.mark.integration_test -def test_static_array_index_out_of_bounds(): - code_py = """ -CONSTANT0: u32[3] = (24, 57, 80, ) - -CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ) - -@exported -def testEntry() -> u32: - return CONSTANT0[16] -""" - - result = Suite(code_py).run_code() - - assert 0 == result.returned_value diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py deleted file mode 100644 index f0c2993..0000000 --- a/tests/integration/test_simple.py +++ /dev/null @@ -1,571 +0,0 @@ -import pytest - -from .helpers import Suite - -TYPE_MAP = { - 'u8': int, - 'u32': int, - 'u64': int, - 'i32': int, - 'i64': int, - 'f32': float, - 'f64': float, -} - -COMPLETE_SIMPLE_TYPES = [ - 'u32', 'u64', - 'i32', 'i64', - 'f32', 'f64', -] - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_return(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 13 -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_addition(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 + 3 -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_subtraction(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 - 3 -""" - - result = Suite(code_py).run_code() - - assert 7 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation -def test_logical_left_shift(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 << 3 -""" - - result = Suite(code_py).run_code() - - assert 80 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_logical_right_shift(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 >> 3 -""" - - result = Suite(code_py).run_code() - - assert 1 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_or(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 | 3 -""" - - result = Suite(code_py).run_code() - - assert 11 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_xor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 ^ 3 -""" - - result = Suite(code_py).run_code() - - assert 9 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_and(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 & 3 -""" - - result = Suite(code_py).run_code() - - assert 2 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['f32', 'f64']) -def test_buildins_sqrt(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return sqrt(25) -""" - - result = Suite(code_py).run_code() - - assert 5 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_arg(type_): - code_py = f""" -@exported -def testEntry(a: {type_}) -> {type_}: - return a -""" - - result = Suite(code_py).run_code(125) - - assert 125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_i32_to_i64(): - code_py = """ -@exported -def testEntry(a: i32) -> i64: - return a -""" - - result = Suite(code_py).run_code(125) - - assert 125 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_i32_plus_i64(): - code_py = """ -@exported -def testEntry(a: i32, b: i64) -> i64: - return a + b -""" - - result = Suite(code_py).run_code(125, 100) - - assert 225 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_f32_to_f64(): - code_py = """ -@exported -def testEntry(a: f32) -> f64: - return a -""" - - result = Suite(code_py).run_code(125.5) - - assert 125.5 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_f32_plus_f64(): - code_py = """ -@exported -def testEntry(a: f32, b: f64) -> f64: - return a + b -""" - - result = Suite(code_py).run_code(125.5, 100.25) - - assert 225.75 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('TODO') -def test_uadd(): - code_py = """ -@exported -def testEntry() -> i32: - return +523 -""" - - result = Suite(code_py).run_code() - - assert 523 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('TODO') -def test_usub(): - code_py = """ -@exported -def testEntry() -> i32: - return -19 -""" - - result = Suite(code_py).run_code() - - assert -19 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('inp', [9, 10, 11, 12]) -def test_if_simple(inp): - code_py = """ -@exported -def testEntry(a: i32) -> i32: - if a > 10: - return 15 - - return 3 -""" - exp_result = 15 if inp > 10 else 3 - - suite = Suite(code_py) - - result = suite.run_code(inp) - assert exp_result == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Such a return is not how things should be') -def test_if_complex(): - code_py = """ -@exported -def testEntry(a: i32) -> i32: - if a > 10: - return 10 - elif a > 0: - return a - else: - return 0 - - return -1 # Required due to function type -""" - - suite = Suite(code_py) - - assert 10 == suite.run_code(20).returned_value - assert 10 == suite.run_code(10).returned_value - - assert 8 == suite.run_code(8).returned_value - - assert 0 == suite.run_code(0).returned_value - assert 0 == suite.run_code(-1).returned_value - -@pytest.mark.integration_test -def test_if_nested(): - code_py = """ -@exported -def testEntry(a: i32, b: i32) -> i32: - if a > 11: - if b > 11: - return 3 - - return 2 - - if b > 11: - return 1 - - return 0 -""" - - suite = Suite(code_py) - - assert 3 == suite.run_code(20, 20).returned_value - assert 2 == suite.run_code(20, 10).returned_value - assert 1 == suite.run_code(10, 20).returned_value - assert 0 == suite.run_code(10, 10).returned_value - -@pytest.mark.integration_test -def test_call_pre_defined(): - code_py = """ -def helper(left: i32, right: i32) -> i32: - return left + right - -@exported -def testEntry() -> i32: - return helper(10, 3) -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - -@pytest.mark.integration_test -def test_call_post_defined(): - code_py = """ -@exported -def testEntry() -> i32: - return helper(10, 3) - -def helper(left: i32, right: i32) -> i32: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 7 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_call_with_expression(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper(10 + 20, 3 + 5) - -def helper(left: {type_}, right: {type_}) -> {type_}: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 22 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.skip('Not yet implemented') -def test_assign(): - code_py = """ - -@exported -def testEntry() -> i32: - a: i32 = 8947 - return a -""" - - result = Suite(code_py).run_code() - - assert 8947 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_struct_0(type_): - code_py = f""" -class CheckedValue: - value: {type_} - -@exported -def testEntry() -> {type_}: - return helper(CheckedValue(23)) - -def helper(cv: CheckedValue) -> {type_}: - return cv.value -""" - - result = Suite(code_py).run_code() - - assert 23 == result.returned_value - -@pytest.mark.integration_test -def test_struct_1(): - code_py = """ -class Rectangle: - height: i32 - width: i32 - border: i32 - -@exported -def testEntry() -> i32: - return helper(Rectangle(100, 150, 2)) - -def helper(shape: Rectangle) -> i32: - return shape.height + shape.width + shape.border -""" - - result = Suite(code_py).run_code() - - assert 252 == result.returned_value - -@pytest.mark.integration_test -def test_struct_2(): - code_py = """ -class Rectangle: - height: i32 - width: i32 - border: i32 - -@exported -def testEntry() -> i32: - return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) - -def helper(shape1: Rectangle, shape2: Rectangle) -> i32: - return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border -""" - - result = Suite(code_py).run_code() - - assert 545 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_tuple_simple_constructor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper((24, 57, 80, )) - -def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: - return vector[0] + vector[1] + vector[2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -def test_tuple_float(): - code_py = """ -@exported -def testEntry() -> f32: - return helper((1.0, 2.0, 3.0, )) - -def helper(v: (f32, f32, f32, )) -> f32: - return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) -""" - - result = Suite(code_py).run_code() - - assert 3.74 < result.returned_value < 3.75 - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_static_array_module_constant(type_): - code_py = f""" -CONSTANT: {type_}[3] = (24, 57, 80, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(array: {type_}[3]) -> {type_}: - return array[0] + array[1] + array[2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_static_array_indexed(type_): - code_py = f""" -CONSTANT: {type_}[3] = (24, 57, 80, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT, 0, 1, 2) - -def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: - return array[i0] + array[i1] + array[i2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -def test_bytes_address(): - code_py = """ -@exported -def testEntry(f: bytes) -> bytes: - return f -""" - - result = Suite(code_py).run_code(b'This is a test') - - # THIS DEPENDS ON THE ALLOCATOR - # A different allocator will return a different value - assert 20 == result.returned_value - -@pytest.mark.integration_test -def test_bytes_length(): - code_py = """ -@exported -def testEntry(f: bytes) -> i32: - return len(f) -""" - - result = Suite(code_py).run_code(b'This is another test') - - assert 20 == result.returned_value - -@pytest.mark.integration_test -def test_bytes_index(): - code_py = """ -@exported -def testEntry(f: bytes) -> u8: - return f[8] -""" - - result = Suite(code_py).run_code(b'This is another test') - - assert 0x61 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('SIMD support is but a dream') -def test_tuple_i32x4(): - code_py = """ -@exported -def testEntry() -> i32x4: - return (51, 153, 204, 0, ) -""" - - result = Suite(code_py).run_code() - - assert (1, 2, 3, 0) == result.returned_value - -@pytest.mark.integration_test -def test_imported(): - code_py = """ -@imported -def helper(mul: i32) -> i32: - pass - -@exported -def testEntry() -> i32: - return helper(2) -""" - - def helper(mul: int) -> int: - return 4238 * mul - - result = Suite(code_py).run_code( - runtime='wasmer', - imports={ - 'helper': helper, - } - ) - - assert 8476 == result.returned_value diff --git a/tests/integration/test_static_checking.py b/tests/integration/test_static_checking.py deleted file mode 100644 index 1544537..0000000 --- a/tests/integration/test_static_checking.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest - -from phasm.parser import phasm_parse -from phasm.exceptions import StaticError - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_function_argument(type_): - code_py = f""" -def helper(a: {type_}) -> (i32, i32, ): - return a -""" - - with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), a is actually {type_}'): - phasm_parse(code_py) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_struct_member(type_): - code_py = f""" -class Struct: - param: {type_} - -def testEntry(arg: Struct) -> (i32, i32, ): - return arg.param -""" - - with pytest.raises(StaticError, match=f'Static error on line 6: Expected \\(i32, i32, \\), arg.param is actually {type_}'): - phasm_parse(code_py) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_tuple_member(type_): - code_py = f""" -def testEntry(arg: ({type_}, )) -> (i32, i32, ): - return arg[0] -""" - - with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), arg\\[0\\] is actually {type_}'): - phasm_parse(code_py) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_function_result(type_): - code_py = f""" -def helper() -> {type_}: - return 1 - -@exported -def testEntry() -> (i32, i32, ): - return helper() -""" - - with pytest.raises(StaticError, match=f'Static error on line 7: Expected \\(i32, i32, \\), helper actually returns {type_}'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_tuple_constant_too_few_values(): - code_py = """ -CONSTANT: (u32, u8, u8, ) = (24, 57, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of tuple values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_tuple_constant_too_many_values(): - code_py = """ -CONSTANT: (u32, u8, u8, ) = (24, 57, 1, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of tuple values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_tuple_constant_type_mismatch(): - code_py = """ -CONSTANT: (u32, u8, u8, ) = (24, 4000, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_too_few_values(): - code_py = """ -CONSTANT: u8[3] = (24, 57, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_too_many_values(): - code_py = """ -CONSTANT: u8[3] = (24, 57, 1, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_type_mismatch(): - code_py = """ -CONSTANT: u8[3] = (24, 4000, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): - phasm_parse(code_py) diff --git a/tests/integration/test_stdlib/__init__.py b/tests/integration/test_stdlib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_stdlib_alloc.py b/tests/integration/test_stdlib/test_alloc.py similarity index 95% rename from tests/integration/test_stdlib_alloc.py rename to tests/integration/test_stdlib/test_alloc.py index da8ccea..96d1fd6 100644 --- a/tests/integration/test_stdlib_alloc.py +++ b/tests/integration/test_stdlib/test_alloc.py @@ -2,8 +2,8 @@ import sys import pytest -from .helpers import write_header -from .runners import RunnerPywasm3 as Runner +from ..helpers import write_header +from ..runners import RunnerPywasm3 as Runner def setup_interpreter(phash_code: str) -> Runner: runner = Runner(phash_code)