Adds runner classes to tests, implements xor for u8, u32, u64

This commit is contained in:
Johan B.W. de Vries 2022-08-06 20:11:39 +02:00
parent c5e2744b3e
commit 253974df24
8 changed files with 258 additions and 46 deletions

View File

@ -31,7 +31,7 @@ lint: venv/.done
venv/bin/pylint phasm venv/bin/pylint phasm
typecheck: venv/.done typecheck: venv/.done
venv/bin/mypy --strict phasm venv/bin/mypy --strict phasm tests/integration/runners.py
venv/.done: requirements.txt venv/.done: requirements.txt
python3.8 -m venv venv python3.8 -m venv venv

2
mypy.ini Normal file
View File

@ -0,0 +1,2 @@
[mypy]
mypy_path=stubs

View File

@ -78,11 +78,19 @@ OPERATOR_MAP = {
'==': 'eq', '==': 'eq',
} }
U8_OPERATOR_MAP = {
# Under the hood, this is an i32
# Implementing XOR is fine since the 3 remaining
# bytes stay zero after this operation
'^': 'xor',
}
U32_OPERATOR_MAP = { U32_OPERATOR_MAP = {
'<': 'lt_u', '<': 'lt_u',
'>': 'gt_u', '>': 'gt_u',
'<=': 'le_u', '<=': 'le_u',
'>=': 'ge_u', '>=': 'ge_u',
'^': 'xor',
} }
U64_OPERATOR_MAP = { U64_OPERATOR_MAP = {
@ -90,6 +98,7 @@ U64_OPERATOR_MAP = {
'>': 'gt_u', '>': 'gt_u',
'<=': 'le_u', '<=': 'le_u',
'>=': 'ge_u', '>=': 'ge_u',
'^': 'xor',
} }
I32_OPERATOR_MAP = { I32_OPERATOR_MAP = {
@ -146,6 +155,10 @@ def expression(inp: ourlang.Expression) -> Statements:
yield from expression(inp.left) yield from expression(inp.left)
yield from expression(inp.right) yield from expression(inp.right)
if isinstance(inp.type, typing.TypeUInt8):
if operator := U8_OPERATOR_MAP.get(inp.operator, None):
yield wasm.Statement(f'i32.{operator}')
return
if isinstance(inp.type, typing.TypeUInt32): if isinstance(inp.type, typing.TypeUInt32):
if operator := OPERATOR_MAP.get(inp.operator, None): if operator := OPERATOR_MAP.get(inp.operator, None):
yield wasm.Statement(f'i32.{operator}') yield wasm.Statement(f'i32.{operator}')

View File

@ -242,6 +242,8 @@ class OurVisitor:
operator = '-' operator = '-'
elif isinstance(node.op, ast.Mult): elif isinstance(node.op, ast.Mult):
operator = '*' operator = '*'
elif isinstance(node.op, ast.BitXor):
operator = '^'
else: else:
raise NotImplementedError(f'Operator {node.op}') raise NotImplementedError(f'Operator {node.op}')

23
stubs/wasm3.pyi Normal file
View File

@ -0,0 +1,23 @@
from typing import Any, Callable
class Module:
...
class Runtime:
...
def load(self, wasm_bin: Module) -> None:
...
def get_memory(self, memid: int) -> memoryview:
...
def find_function(self, name: str) -> Callable[[Any], Any]:
...
class Environment:
def new_runtime(self, mem_size: int) -> Runtime:
...
def parse_module(self, wasm_bin: bytes) -> Module:
...

View File

@ -0,0 +1,168 @@
"""
Runners to help run WebAssembly code on various interpreters
"""
from typing import Any, TextIO
import os
import subprocess
import sys
import tempfile
from phasm.compiler import phasm_compile
from phasm.parser import phasm_parse
from phasm import ourlang
from phasm import wasm
import wasm3
def wat2wasm(code_wat: str) -> bytes:
"""
Converts the given WebAssembly Assembly code into WebAssembly Binary
"""
path = os.environ.get('WAT2WASM', 'wat2wasm')
with tempfile.NamedTemporaryFile('w+t') as input_fp:
input_fp.write(code_wat)
input_fp.flush()
with tempfile.NamedTemporaryFile('w+b') as output_fp:
subprocess.run(
[
path,
input_fp.name,
'-o',
output_fp.name,
],
check=True,
)
output_fp.seek(0)
return output_fp.read()
class RunnerBase:
"""
Base class
"""
phasm_code: str
phasm_ast: ourlang.Module
wasm_ast: wasm.Module
wasm_asm: str
wasm_bin: bytes
def __init__(self, phasm_code: str) -> None:
self.phasm_code = phasm_code
def dump_phasm_code(self, textio: TextIO) -> None:
"""
Dumps the input Phasm code for debugging
"""
_dump_code(textio, self.phasm_code)
def parse(self) -> None:
"""
Parses the Phasm code into an AST
"""
self.phasm_ast = phasm_parse(self.phasm_code)
def compile_ast(self) -> None:
"""
Compiles the Phasm AST into an WebAssembly AST
"""
self.wasm_ast = phasm_compile(self.phasm_ast)
def compile_wat(self) -> None:
"""
Compiles the WebAssembly AST into WebAssembly Assembly code
"""
self.wasm_asm = self.wasm_ast.to_wat()
def dump_wasm_wat(self, textio: TextIO) -> None:
"""
Dumps the intermediate WebAssembly Assembly code for debugging
"""
_dump_code(textio, self.wasm_asm)
def compile_wasm(self) -> None:
"""
Compiles the WebAssembly AST into WebAssembly Binary
"""
self.wasm_bin = wat2wasm(self.wasm_asm)
def interpreter_setup(self) -> None:
"""
Sets up the interpreter
"""
raise NotImplementedError
def interpreter_load(self) -> None:
"""
Loads the code into the interpreter
"""
raise NotImplementedError
def interpreter_dump_memory(self, textio: TextIO) -> None:
"""
Dumps the interpreters memory for debugging
"""
raise NotImplementedError
def call(self, function: str, *args: Any) -> Any:
"""
Calls the given function with the given arguments, returning the result
"""
raise NotImplementedError
class RunnerPywasm3(RunnerBase):
"""
Implements a runner for pywasm3
See https://pypi.org/project/pywasm3/
"""
env: wasm3.Environment
rtime: wasm3.Runtime
mod: wasm3.Module
def interpreter_setup(self) -> None:
self.env = wasm3.Environment()
self.rtime = self.env.new_runtime(1024 * 1024)
def interpreter_load(self) -> None:
self.mod = self.env.parse_module(self.wasm_bin)
self.rtime.load(self.mod)
def interpreter_dump_memory(self, textio: TextIO) -> None:
_dump_memory(textio, self.rtime.get_memory(0))
def call(self, function: str, *args: Any) -> Any:
return self.rtime.find_function(function)(*args)
def _dump_memory(textio: TextIO, mem: bytes) -> None:
line_width = 16
prev_line = None
skip = False
for idx in range(0, len(mem), line_width):
line = ''
for idx2 in range(0, line_width):
line += f'{mem[idx + idx2]:02X}'
if idx2 % 2 == 1:
line += ' '
if prev_line == line:
if not skip:
textio.write('**\n')
skip = True
else:
textio.write(f'{idx:08x} {line}\n')
prev_line = line
def _dump_code(textio: TextIO, text: str) -> None:
line_list = text.split('\n')
line_no_width = len(str(len(line_list)))
for line_no, line_txt in enumerate(line_list):
textio.write('{} {}\n'.format(
str(line_no + 1).zfill(line_no_width),
line_txt,
))

View File

@ -60,6 +60,20 @@ def testEntry() -> {type_}:
assert 7 == result.returned_value assert 7 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_xor(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10 ^ 3
"""
result = Suite(code_py).run_code()
assert 9 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['f32', 'f64']) @pytest.mark.parametrize('type_', ['f32', 'f64'])
def test_buildins_sqrt(type_): def test_buildins_sqrt(type_):

View File

@ -2,30 +2,25 @@ import sys
import pytest import pytest
import wasm3 from .helpers import DASHES
from .runners import RunnerPywasm3
from phasm.compiler import phasm_compile def setup_interpreter(phash_code: str) -> RunnerPywasm3:
from phasm.parser import phasm_parse runner = RunnerPywasm3(phash_code)
from .helpers import DASHES, wat2wasm, _dump_memory, _write_numbered_lines runner.parse()
runner.compile_ast()
def setup_interpreter(code_phasm): runner.compile_wat()
phasm_module = phasm_parse(code_phasm) runner.compile_wasm()
wasm_module = phasm_compile(phasm_module) runner.interpreter_setup()
code_wat = wasm_module.to_wat() runner.interpreter_load()
sys.stderr.write(f'{DASHES} Phasm {DASHES}\n')
runner.dump_phasm_code(sys.stderr)
sys.stderr.write(f'{DASHES} Assembly {DASHES}\n') sys.stderr.write(f'{DASHES} Assembly {DASHES}\n')
_write_numbered_lines(code_wat) runner.dump_wasm_wat(sys.stderr)
code_wasm = wat2wasm(code_wat) return runner
env = wasm3.Environment()
mod = env.parse_module(code_wasm)
rtime = env.new_runtime(1024 * 1024)
rtime.load(mod)
return rtime
@pytest.mark.integration_test @pytest.mark.integration_test
def test___init__(): def test___init__():
@ -35,20 +30,22 @@ def testEntry() -> u8:
return 13 return 13
""" """
rtime = setup_interpreter(code_py) runner = setup_interpreter(code_py)
memory = runner.rtime.get_memory(0)
# Garbage in the memory so we can test for it
for idx in range(128): for idx in range(128):
rtime.get_memory(0)[idx] = idx memory[idx] = idx
sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n')
_dump_memory(rtime.get_memory(0)) runner.interpreter_dump_memory(sys.stderr)
rtime.find_function('stdlib.alloc.__init__')() runner.call('stdlib.alloc.__init__')
sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') sys.stderr.write(f'{DASHES} Memory (post run) {DASHES}\n')
_dump_memory(rtime.get_memory(0)) runner.interpreter_dump_memory(sys.stderr)
memory = rtime.get_memory(0).tobytes() memory = memory.tobytes()
assert ( assert (
b'\xC0\xA1\x00\x00' b'\xC0\xA1\x00\x00'
@ -66,16 +63,13 @@ def testEntry() -> u8:
return 13 return 13
""" """
rtime = setup_interpreter(code_py) runner = setup_interpreter(code_py)
for idx in range(128):
rtime.get_memory(0)[idx] = idx
sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n')
_dump_memory(rtime.get_memory(0)) runner.interpreter_dump_memory(sys.stderr)
with pytest.raises(RuntimeError, match='unreachable executed'): with pytest.raises(RuntimeError, match='unreachable executed'):
rtime.find_function('stdlib.alloc.__alloc__')(32) runner.call('stdlib.alloc.__alloc__', 32)
@pytest.mark.integration_test @pytest.mark.integration_test
def test___alloc___ok(): def test___alloc___ok():
@ -85,23 +79,19 @@ def testEntry() -> u8:
return 13 return 13
""" """
rtime = setup_interpreter(code_py) runner = setup_interpreter(code_py)
memory = runner.rtime.get_memory(0)
for idx in range(128):
rtime.get_memory(0)[idx] = idx
sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n')
_dump_memory(rtime.get_memory(0)) runner.interpreter_dump_memory(sys.stderr)
rtime.find_function('stdlib.alloc.__init__')() runner.call('stdlib.alloc.__init__')
offset0 = rtime.find_function('stdlib.alloc.__alloc__')(32) offset0 = runner.call('stdlib.alloc.__alloc__', 32)
offset1 = rtime.find_function('stdlib.alloc.__alloc__')(32) offset1 = runner.call('stdlib.alloc.__alloc__', 32)
offset2 = rtime.find_function('stdlib.alloc.__alloc__')(32) offset2 = runner.call('stdlib.alloc.__alloc__', 32)
sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') sys.stderr.write(f'{DASHES} Memory (post run) {DASHES}\n')
_dump_memory(rtime.get_memory(0)) runner.interpreter_dump_memory(sys.stderr)
memory = rtime.get_memory(0).tobytes()
assert b'\xC0\xA1\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' == memory[0:12] assert b'\xC0\xA1\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' == memory[0:12]