From 374231d206c802136d9241bbf06e2783a13c5820 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 25 Jun 2022 20:45:33 +0200 Subject: [PATCH] bytes, u8 types --- py2wasm/compiler.py | 106 +++++++++++++++++++---- py2wasm/ourlang.py | 85 +++++++++++++++++- tests/integration/helpers.py | 47 +++++++++- tests/integration/test_runtime_checks.py | 15 ++++ tests/integration/test_simple.py | 56 ++++++++++-- 5 files changed, 276 insertions(+), 33 deletions(-) create mode 100644 tests/integration/test_runtime_checks.py diff --git a/py2wasm/compiler.py b/py2wasm/compiler.py index 40263a2..4b71bae 100644 --- a/py2wasm/compiler.py +++ b/py2wasm/compiler.py @@ -9,6 +9,11 @@ from . import wasm Statements = Generator[wasm.Statement, None, None] def type_(inp: ourlang.OurType) -> wasm.OurType: + if isinstance(inp, ourlang.OurTypeUInt8): + # WebAssembly has only support for 32 and 64 bits + # So we need to store more memory per byte + return wasm.OurTypeInt32() + if isinstance(inp, ourlang.OurTypeInt32): return wasm.OurTypeInt32() @@ -21,7 +26,7 @@ def type_(inp: ourlang.OurType) -> wasm.OurType: if isinstance(inp, ourlang.OurTypeFloat64): return wasm.OurTypeFloat64() - if isinstance(inp, (ourlang.Struct, ourlang.OurTypeTuple, )): + if isinstance(inp, (ourlang.Struct, ourlang.OurTypeTuple, ourlang.OurTypeBytes)): # Structs and tuples are passed as pointer # And pointers are i32 return wasm.OurTypeInt32() @@ -51,6 +56,10 @@ I64_OPERATOR_MAP = { # TODO: Introduce UInt32 type } def expression(inp: ourlang.Expression) -> Statements: + if isinstance(inp, ourlang.ConstantUInt8): + yield wasm.Statement('i32.const', str(inp.value)) + return + if isinstance(inp, ourlang.ConstantInt32): yield wasm.Statement('i32.const', str(inp.value)) return @@ -112,6 +121,12 @@ def expression(inp: ourlang.Expression) -> Statements: yield wasm.Statement(f'f64.{inp.operator}') return + if isinstance(inp.type, ourlang.OurTypeInt32): + if inp.operator == 'len': + if isinstance(inp.right.type, ourlang.OurTypeBytes): + yield wasm.Statement('i32.load') + return + raise NotImplementedError(expression, inp.type, inp.operator) if isinstance(inp, ourlang.FunctionCall): @@ -121,17 +136,33 @@ def expression(inp: ourlang.Expression) -> Statements: yield wasm.Statement('call', '${}'.format(inp.function.name)) return - if isinstance(inp, ourlang.AccessStructMember): - # FIXME: Properly implement this - # inp.type.render() is also a hack that doesn't really work consistently - if not isinstance(inp.type, ( - ourlang.OurTypeInt32, ourlang.OurTypeFloat32, - ourlang.OurTypeInt64, ourlang.OurTypeFloat64, - )): + if isinstance(inp, ourlang.AccessBytesIndex): + if not isinstance(inp.type, ourlang.OurTypeUInt8): raise NotImplementedError(inp, inp.type) + if not isinstance(inp.offset, int): + raise NotImplementedError(inp, inp.offset) + yield from expression(inp.varref) - yield wasm.Statement(inp.type.render() + '.load', 'offset=' + str(inp.member.offset)) + yield wasm.Statement('i32.const', str(inp.offset)) + yield wasm.Statement('call', '$___access_bytes_index___') + return + + if isinstance(inp, ourlang.AccessStructMember): + if isinstance(inp.member.type, ourlang.OurTypeUInt8): + mtyp = 'i32' + else: + # FIXME: Properly implement this + # inp.type.render() is also a hack that doesn't really work consistently + if not isinstance(inp.member.type, ( + ourlang.OurTypeInt32, ourlang.OurTypeFloat32, + ourlang.OurTypeInt64, ourlang.OurTypeFloat64, + )): + raise NotImplementedError + mtyp = inp.member.type.render() + + yield from expression(inp.varref) + yield wasm.Statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) return if isinstance(inp, ourlang.AccessTupleMember): @@ -221,7 +252,8 @@ def module(inp: ourlang.Module) -> wasm.Module: result = wasm.Module() result.functions = [ - _generate_allocator(inp), + _generate____new_reference___(inp), + _generate____access_bytes_index___(inp), ] + [ function(x) for x in inp.functions.values() @@ -229,7 +261,7 @@ def module(inp: ourlang.Module) -> wasm.Module: return result -def _generate_allocator(mod: ourlang.Module) -> wasm.Function: +def _generate____new_reference___(mod: ourlang.Module) -> wasm.Function: return wasm.Function( '___new_reference___', True, @@ -252,6 +284,38 @@ def _generate_allocator(mod: ourlang.Module) -> wasm.Function: ], ) +def _generate____access_bytes_index___(mod: ourlang.Module) -> wasm.Function: + return wasm.Function( + '___access_bytes_index___', + False, + [ + ('byt', type_(mod.types['i32']), ), + ('ofs', type_(mod.types['i32']), ), + ], + [ + ], + type_(mod.types['i32']), + [ + wasm.Statement('local.get', '$ofs'), + wasm.Statement('local.get', '$byt'), + wasm.Statement('i32.load'), + wasm.Statement('i32.lt_u'), + + wasm.Statement('if', comment='$ofs < len($byt)'), + wasm.Statement('local.get', '$byt'), + wasm.Statement('i32.const', '4', comment='Leading size field'), + wasm.Statement('i32.add'), + wasm.Statement('local.get', '$ofs'), + wasm.Statement('i32.add'), + wasm.Statement('i32.load8_u', comment='Within bounds'), + wasm.Statement('return'), + wasm.Statement('end'), + + wasm.Statement('i32.const', str(0), comment='Out of bounds'), + wasm.Statement('return'), + ], + ) + def _generate_tuple_constructor(inp: ourlang.TupleConstructor) -> Statements: yield wasm.Statement('i32.const', str(inp.tuple.alloc_size())) yield wasm.Statement('call', '$___new_reference___') @@ -280,16 +344,20 @@ def _generate_struct_constructor(inp: ourlang.StructConstructor) -> Statements: yield wasm.Statement('local.set', '$___new_reference___addr') for member in inp.struct.members: - # FIXME: Properly implement this - # inp.type.render() is also a hack that doesn't really work consistently - if not isinstance(member.type, ( - ourlang.OurTypeInt32, ourlang.OurTypeFloat32, - ourlang.OurTypeInt64, ourlang.OurTypeFloat64, - )): - raise NotImplementedError + if isinstance(member.type, ourlang.OurTypeUInt8): + mtyp = 'i32' + else: + # FIXME: Properly implement this + # inp.type.render() is also a hack that doesn't really work consistently + if not isinstance(member.type, ( + ourlang.OurTypeInt32, ourlang.OurTypeFloat32, + ourlang.OurTypeInt64, ourlang.OurTypeFloat64, + )): + raise NotImplementedError + mtyp = member.type.render() yield wasm.Statement('local.get', '$___new_reference___addr') yield wasm.Statement('local.get', f'${member.name}') - yield wasm.Statement(f'{member.type.render()}.store', 'offset=' + str(member.offset)) + yield wasm.Statement(f'{mtyp}.store', 'offset=' + str(member.offset)) yield wasm.Statement('local.get', '$___new_reference___addr') diff --git a/py2wasm/ourlang.py b/py2wasm/ourlang.py index ba87aec..49398f7 100644 --- a/py2wasm/ourlang.py +++ b/py2wasm/ourlang.py @@ -38,6 +38,18 @@ class OurTypeNone(OurType): def render(self) -> str: return 'None' +class OurTypeUInt8(OurType): + """ + The Integer type, unsigned and 8 bits wide + """ + __slots__ = () + + def render(self) -> str: + return 'u8' + + def alloc_size(self) -> int: + return 4 # Int32 under the hood + class OurTypeInt32(OurType): """ The Integer type, signed and 32 bits wide @@ -86,6 +98,15 @@ class OurTypeFloat64(OurType): def alloc_size(self) -> int: return 8 +class OurTypeBytes(OurType): + """ + The bytes type + """ + __slots__ = () + + def render(self) -> str: + return 'bytes' + class TupleMember: """ Represents a tuple member @@ -146,6 +167,21 @@ class Constant(Expression): """ __slots__ = () +class ConstantUInt8(Constant): + """ + An UInt8 constant value expression within a statement + """ + __slots__ = ('value', ) + + value: int + + def __init__(self, type_: OurTypeUInt8, value: int) -> None: + super().__init__(type_) + self.value = value + + def render(self) -> str: + return str(self.value) + class ConstantInt32(Constant): """ An Int32 constant value expression within a statement @@ -257,7 +293,7 @@ class UnaryOp(Expression): self.right = right def render(self) -> str: - if self.operator in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if self.operator in WEBASSEMBLY_BUILDIN_FLOAT_OPS or self.operator == 'len': return f'{self.operator}({self.right.render()})' return f'{self.operator}{self.right.render()}' @@ -291,6 +327,24 @@ class FunctionCall(Expression): return f'{self.function.name}({args})' +class AccessBytesIndex(Expression): + """ + Access a bytes index for reading + """ + __slots__ = ('varref', 'offset', ) + + varref: VariableReference + offset: int + + def __init__(self, type_: OurType, varref: VariableReference, offset: int) -> None: + super().__init__(type_) + + self.varref = varref + self.offset = offset + + def render(self) -> str: + return f'{self.varref.render()}[{self.offset}]' + class AccessStructMember(Expression): """ Access a struct member for reading of writing @@ -554,10 +608,12 @@ class Module: def __init__(self) -> None: self.types = { + 'u8': OurTypeUInt8(), 'i32': OurTypeInt32(), 'i64': OurTypeInt64(), 'f32': OurTypeFloat32(), 'f64': OurTypeFloat64(), + 'bytes': OurTypeBytes(), } self.functions = {} self.structs = {} @@ -900,6 +956,18 @@ class OurVisitor: 'sqrt', self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), ) + elif node.func.id == 'len': + if not isinstance(exp_type, OurTypeInt32): + _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') + + if 1 != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') + + return UnaryOp( + exp_type, + 'len', + self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), + ) else: if node.func.id not in module.functions: _raise_static_error(node, 'Call to undefined function') @@ -966,6 +1034,13 @@ class OurVisitor: _raise_static_error(node, f'Undefined variable {node.value.id}') node_typ = our_locals[node.value.id] + if isinstance(node_typ, OurTypeBytes): + return AccessBytesIndex( + module.types['u8'], + VariableReference(node_typ, node.value.id), + idx, + ) + if not isinstance(node_typ, OurTypeTuple): _raise_static_error(node, f'Cannot take index of non-tuple {node.value.id}') @@ -987,6 +1062,14 @@ class OurVisitor: _not_implemented(node.kind is None, 'Constant.kind') + if isinstance(exp_type, OurTypeUInt8): + if not isinstance(node.value, int): + _raise_static_error(node, 'Expected integer value') + + # FIXME: Range check + + return ConstantUInt8(exp_type, node.value) + if isinstance(exp_type, OurTypeInt32): if not isinstance(node.value, int): _raise_static_error(node, 'Expected integer value') diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index d9f299e..a617c26 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -99,6 +99,15 @@ def _run_pywasm(code_wasm, args): runtime = pywasm.Runtime(module, result.make_imports(), {}) + def set_byte(idx, byt): + runtime.store.mems[0].data[idx] = byt + + args = _convert_bytes_arguments( + args, + lambda x: runtime.exec('___new_reference___', [x]), + set_byte + ) + sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') _dump_memory(runtime.store.mems[0].data) @@ -120,13 +129,22 @@ def _run_pywasm3(code_wasm, args): rtime = env.new_runtime(1024 * 1024) rtime.load(mod) - # sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') - # _dump_memory(rtime.get_memory(0).tobytes()) + def set_byte(idx, byt): + rtime.get_memory(0)[idx] = byt + + args = _convert_bytes_arguments( + args, + rtime.find_function('___new_reference___'), + set_byte + ) + + sys.stderr.write(f'{DASHES} Memory (pre run) {DASHES}\n') + _dump_memory(rtime.get_memory(0)) result.returned_value = rtime.find_function('testEntry')(*args) - # sys.stderr.write(f'{DASHES} Memory (post run) {DASHES}\n') - # _dump_memory(rtime.get_memory(0).tobytes()) + sys.stderr.write(f'{DASHES} Memory (post run) {DASHES}\n') + _dump_memory(rtime.get_memory(0)) return result @@ -173,6 +191,27 @@ def _run_wasmer(code_wasm, args): return result +def _convert_bytes_arguments(args, new_reference, set_byte): + result = [] + for arg in args: + if not isinstance(arg, bytes): + result.append(arg) + continue + + # TODO: Implement and use the bytes constructor function + offset = new_reference(len(arg) + 4) + result.append(offset) + + # Store the length prefix + for idx, byt in enumerate(len(arg).to_bytes(4, byteorder='little')): + set_byte(offset + idx, byt) + + # Store the actual bytes + for idx, byt in enumerate(arg): + set_byte(offset + 4 + idx, byt) + + return result + def _dump_memory(mem): line_width = 16 diff --git a/tests/integration/test_runtime_checks.py b/tests/integration/test_runtime_checks.py new file mode 100644 index 0000000..abe7081 --- /dev/null +++ b/tests/integration/test_runtime_checks.py @@ -0,0 +1,15 @@ +import pytest + +from .helpers import Suite + +@pytest.mark.integration_test +def test_bytes_index_out_of_bounds(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[50] +""" + + result = Suite(code_py, 'test_call').run_code(b'Short', b'Long' * 100) + + assert 0 == result.returned_value diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index ea281e7..f5f28f9 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -3,6 +3,7 @@ import pytest from .helpers import Suite TYPE_MAP = { + 'u8': int, 'i32': int, 'i64': int, 'f32': float, @@ -10,7 +11,7 @@ TYPE_MAP = { } @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) +@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64', 'u8']) def test_return(type_): code_py = f""" @exported @@ -66,7 +67,7 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) +@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64', 'u8']) def test_arg(type_): code_py = f""" @exported @@ -273,22 +274,23 @@ def testEntry() -> i32: assert [] == result.log_int32_list @pytest.mark.integration_test -def test_struct_0(): - code_py = """ +@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64', 'u8']) +def test_struct_0(type_): + code_py = f""" class CheckedValue: - value: i32 + value: {type_} @exported -def testEntry() -> i32: - return helper(CheckedValue(2345)) +def testEntry() -> {type_}: + return helper(CheckedValue(23)) -def helper(cv: CheckedValue) -> i32: +def helper(cv: CheckedValue) -> {type_}: return cv.value """ result = Suite(code_py, 'test_call').run_code() - assert 2345 == result.returned_value + assert 23 == result.returned_value assert [] == result.log_int32_list @pytest.mark.integration_test @@ -366,6 +368,42 @@ def helper(v: (f32, f32, f32, )) -> f32: assert 3.74 < result.returned_value < 3.75 assert [] == result.log_int32_list +@pytest.mark.integration_test +def test_bytes_address(): + code_py = """ +@exported +def testEntry(f: bytes) -> bytes: + return f +""" + + result = Suite(code_py, 'test_call').run_code(b'This is a test') + + assert 4 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_length(): + code_py = """ +@exported +def testEntry(f: bytes) -> i32: + return len(f) +""" + + result = Suite(code_py, 'test_call').run_code(b'This is another test') + + assert 20 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[8] +""" + + result = Suite(code_py, 'test_call').run_code(b'This is another test') + + assert 0x61 == result.returned_value + @pytest.mark.integration_test @pytest.mark.skip('SIMD support is but a dream') def test_tuple_i32x4():