Starting on platform runtime
This commit is contained in:
parent
d1b593d4e5
commit
b815b12bb2
0
phasm/transformers/__init__.py
Normal file
0
phasm/transformers/__init__.py
Normal file
0
phasm/transformers/wasm/__init__.py
Normal file
0
phasm/transformers/wasm/__init__.py
Normal file
6
phasm/transformers/wasm/phasmplatform/__init__.py
Normal file
6
phasm/transformers/wasm/phasmplatform/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from . import splitoncall
|
||||||
|
|
||||||
|
from phasm.wasm import Module
|
||||||
|
|
||||||
|
def transform(module: Module) -> None:
|
||||||
|
splitoncall.transform(module)
|
||||||
56
phasm/transformers/wasm/phasmplatform/splitoncall.py
Normal file
56
phasm/transformers/wasm/phasmplatform/splitoncall.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from phasm.wasm import (
|
||||||
|
Function, Module, Statement
|
||||||
|
)
|
||||||
|
|
||||||
|
def split_on_call_function(module: Module, function: Function) -> List[Function]:
|
||||||
|
function_list = []
|
||||||
|
statement_list = []
|
||||||
|
idx = 0
|
||||||
|
while function.statements:
|
||||||
|
stmt = function.statements.pop(0)
|
||||||
|
if not stmt.name == 'call':
|
||||||
|
statement_list.append(stmt)
|
||||||
|
continue
|
||||||
|
|
||||||
|
function_list.append(Function(
|
||||||
|
f'{function.name}.{idx}',
|
||||||
|
None,
|
||||||
|
function.params,
|
||||||
|
function.locals,
|
||||||
|
function.result,
|
||||||
|
statement_list + [stmt] + [Statement('call', f'${function.name}.{idx + 1}')]
|
||||||
|
))
|
||||||
|
statement_list = []
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
function_list.append(Function(
|
||||||
|
f'{function.name}.{idx}',
|
||||||
|
None,
|
||||||
|
function.params,
|
||||||
|
function.locals,
|
||||||
|
function.result,
|
||||||
|
statement_list
|
||||||
|
))
|
||||||
|
|
||||||
|
if function.exported_name:
|
||||||
|
function_list.append(Function(
|
||||||
|
function.name + '.e',
|
||||||
|
function.exported_name,
|
||||||
|
function.params,
|
||||||
|
[],
|
||||||
|
function.result,
|
||||||
|
[
|
||||||
|
Statement('local.get', '$' + x[0])
|
||||||
|
for x in function.params
|
||||||
|
] + [Statement('call', f'${function.name}.0')]
|
||||||
|
))
|
||||||
|
|
||||||
|
return function_list
|
||||||
|
|
||||||
|
def transform(module: Module) -> None:
|
||||||
|
new_functions = []
|
||||||
|
for func in module.functions:
|
||||||
|
new_functions.extend(split_on_call_function(module, func))
|
||||||
|
module.functions = new_functions
|
||||||
@ -15,6 +15,12 @@ class WatSerializable:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(self, 'to_wat')
|
raise NotImplementedError(self, 'to_wat')
|
||||||
|
|
||||||
|
def alloc_size(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns how many bytes a variable of this type takes up in memory
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(self, 'alloc_size')
|
||||||
|
|
||||||
class WasmType(WatSerializable):
|
class WasmType(WatSerializable):
|
||||||
"""
|
"""
|
||||||
Type base class
|
Type base class
|
||||||
@ -36,6 +42,9 @@ class WasmTypeInt32(WasmType):
|
|||||||
def to_wat(self) -> str:
|
def to_wat(self) -> str:
|
||||||
return 'i32'
|
return 'i32'
|
||||||
|
|
||||||
|
def alloc_size(self) -> int:
|
||||||
|
return 4
|
||||||
|
|
||||||
class WasmTypeInt64(WasmType):
|
class WasmTypeInt64(WasmType):
|
||||||
"""
|
"""
|
||||||
i64 value
|
i64 value
|
||||||
@ -52,6 +61,9 @@ class WasmTypeFloat32(WasmType):
|
|||||||
def to_wat(self) -> str:
|
def to_wat(self) -> str:
|
||||||
return 'f32'
|
return 'f32'
|
||||||
|
|
||||||
|
def alloc_size(self) -> int:
|
||||||
|
return 4
|
||||||
|
|
||||||
class WasmTypeFloat64(WasmType):
|
class WasmTypeFloat64(WasmType):
|
||||||
"""
|
"""
|
||||||
f64 value
|
f64 value
|
||||||
@ -179,6 +191,34 @@ class ModuleMemory(WatSerializable):
|
|||||||
'(export "memory" (memory 0))\n'
|
'(export "memory" (memory 0))\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class TableElement(WatSerializable):
|
||||||
|
"""
|
||||||
|
Represents a Web Assembly table element
|
||||||
|
"""
|
||||||
|
def __init__(self, offset: int, args: Iterable[str]) -> None:
|
||||||
|
self.offset = offset
|
||||||
|
self.args = [*args]
|
||||||
|
|
||||||
|
def to_wat(self) -> str:
|
||||||
|
args = ' '.join(self.args)
|
||||||
|
|
||||||
|
return f'(elem (i32.const {self.offset}) {args})'
|
||||||
|
|
||||||
|
class Table(WatSerializable):
|
||||||
|
"""
|
||||||
|
Represents a Web Assembly table
|
||||||
|
"""
|
||||||
|
def __init__(self, size: int, typ: str, elements: List[TableElement]) -> None:
|
||||||
|
self.size = size
|
||||||
|
self.type = typ
|
||||||
|
self.elements = [*elements]
|
||||||
|
|
||||||
|
def to_wat(self) -> str:
|
||||||
|
return (
|
||||||
|
f'(table {self.size} {self.type})\n '
|
||||||
|
+ '\n '.join(x.to_wat() for x in self.elements)
|
||||||
|
)
|
||||||
|
|
||||||
class Module(WatSerializable):
|
class Module(WatSerializable):
|
||||||
"""
|
"""
|
||||||
Represents a Web Assembly module
|
Represents a Web Assembly module
|
||||||
@ -187,13 +227,15 @@ class Module(WatSerializable):
|
|||||||
self.imports: List[Import] = []
|
self.imports: List[Import] = []
|
||||||
self.functions: List[Function] = []
|
self.functions: List[Function] = []
|
||||||
self.memory = ModuleMemory()
|
self.memory = ModuleMemory()
|
||||||
|
self.tables: List[Table] = []
|
||||||
|
|
||||||
def to_wat(self) -> str:
|
def to_wat(self) -> str:
|
||||||
"""
|
"""
|
||||||
Generates the text version
|
Generates the text version
|
||||||
"""
|
"""
|
||||||
return '(module\n {}\n {}\n {})\n'.format(
|
return '(module\n {}\n {}\n {}\n {})\n'.format(
|
||||||
'\n '.join(x.to_wat() for x in self.imports),
|
'\n '.join(x.to_wat() for x in self.imports),
|
||||||
self.memory.to_wat(),
|
self.memory.to_wat(),
|
||||||
|
'\n '.join(x.to_wat() for x in self.tables),
|
||||||
'\n '.join(x.to_wat() for x in self.functions),
|
'\n '.join(x.to_wat() for x in self.functions),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,6 +3,9 @@ from typing import Any, Callable
|
|||||||
class Module:
|
class Module:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def link_function(self, module_name: str, function_name: str, function: Callable[[Any], Any]) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|||||||
@ -152,12 +152,13 @@ class RunnerPywasm3(RunnerBase):
|
|||||||
self.rtime = self.env.new_runtime(1024 * 1024)
|
self.rtime = self.env.new_runtime(1024 * 1024)
|
||||||
|
|
||||||
def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
|
def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
|
||||||
if imports is not None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
self.mod = self.env.parse_module(self.wasm_bin)
|
self.mod = self.env.parse_module(self.wasm_bin)
|
||||||
self.rtime.load(self.mod)
|
self.rtime.load(self.mod)
|
||||||
|
|
||||||
|
if imports is not None:
|
||||||
|
for key, val in imports.items():
|
||||||
|
self.mod.link_function('imports', key, val)
|
||||||
|
|
||||||
def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
|
def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
|
||||||
memory = self.rtime.get_memory(0)
|
memory = self.rtime.get_memory(0)
|
||||||
|
|
||||||
|
|||||||
0
tests/integration/test_transformers/__init__.py
Normal file
0
tests/integration/test_transformers/__init__.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
from phasm import wasm
|
||||||
|
|
||||||
|
from ...runners import RunnerPywasm, RunnerPywasm3, RunnerWasmtime, RunnerWasmer
|
||||||
|
|
||||||
|
from phasm.transformers.wasm import phasmplatform as sut
|
||||||
|
|
||||||
|
none = wasm.WasmTypeNone()
|
||||||
|
f32 = wasm.WasmTypeFloat32()
|
||||||
|
|
||||||
|
def run(module, func, *, imports = None):
|
||||||
|
# runner = RunnerPywasm('-')
|
||||||
|
runner = RunnerPywasm3('-')
|
||||||
|
# runner = RunnerWasmtime('-')
|
||||||
|
runner.wasm_ast = module
|
||||||
|
runner.compile_wat()
|
||||||
|
runner.dump_wasm_wat(sys.stderr)
|
||||||
|
runner.compile_wasm()
|
||||||
|
runner.interpreter_setup()
|
||||||
|
runner.interpreter_load(imports)
|
||||||
|
return runner.call(func)
|
||||||
|
|
||||||
|
|
||||||
|
def return_f32_30_0() -> float:
|
||||||
|
return 30.0
|
||||||
|
|
||||||
|
|
||||||
|
def make_return_f32_30_0_import():
|
||||||
|
return wasm.Import('imports', 'return_f32_30_0', 'return_f32_30_0', [], f32)
|
||||||
|
|
||||||
|
|
||||||
|
def make_func1():
|
||||||
|
"""
|
||||||
|
func1: No params, just a return value
|
||||||
|
|
||||||
|
In Haskell:
|
||||||
|
func1 :: f32
|
||||||
|
func1 = 30
|
||||||
|
"""
|
||||||
|
return wasm.Function('func1', 'func1', [], [], f32, [
|
||||||
|
wasm.Statement('f32.const', '30'),
|
||||||
|
wasm.Statement('return'),
|
||||||
|
])
|
||||||
|
|
||||||
|
def make_func2():
|
||||||
|
"""
|
||||||
|
func2: Calls imported method
|
||||||
|
|
||||||
|
In Haskell:
|
||||||
|
return_f32_30_0 :: IO(f32)
|
||||||
|
|
||||||
|
func2 :: IO(f32)
|
||||||
|
func2 = return_f32_30_0
|
||||||
|
"""
|
||||||
|
return wasm.Function('func2', 'func2', [], [], f32, [
|
||||||
|
wasm.Statement('call', '$return_f32_30_0'),
|
||||||
|
wasm.Statement('return'),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def test_const_return_func():
|
||||||
|
module = wasm.Module()
|
||||||
|
module.functions.append(make_func1())
|
||||||
|
|
||||||
|
assert 30 == run(module, 'func1')
|
||||||
|
|
||||||
|
sut.transform(module)
|
||||||
|
|
||||||
|
assert 30 == run(module, 'func1')
|
||||||
|
|
||||||
|
def test_call_imported_method():
|
||||||
|
module = wasm.Module()
|
||||||
|
module.imports.append(make_return_f32_30_0_import())
|
||||||
|
module.functions.append(make_func2())
|
||||||
|
|
||||||
|
assert 30 == run(module, 'func2', imports={'return_f32_30_0': return_f32_30_0})
|
||||||
|
|
||||||
|
sut.transform(module)
|
||||||
|
|
||||||
|
assert 30 == run(module, 'func2', imports={'return_f32_30_0': return_f32_30_0})
|
||||||
|
|
||||||
|
def test_parameters():
|
||||||
|
assert 0
|
||||||
|
|
||||||
|
# func1: No params, no return value
|
||||||
|
# func1 = func4 10 20
|
||||||
|
func1_statements = [
|
||||||
|
wasm.Statement('f32.const', '30'),
|
||||||
|
wasm.Statement('f32.const', '8'),
|
||||||
|
wasm.Statement('call', '$func3'),
|
||||||
|
wasm.Statement('f32.const', '0'),
|
||||||
|
wasm.Statement('call', '$log'),
|
||||||
|
]
|
||||||
|
func1 = wasm.Function('func1', 'func1', [], [], none, func1_statements)
|
||||||
|
|
||||||
|
|
||||||
|
# func2: No params, with return value
|
||||||
|
# func2 = remote_value_get
|
||||||
|
func2_statements = [
|
||||||
|
wasm.Statement('call', '$remote_value_get'),
|
||||||
|
]
|
||||||
|
func2 = wasm.Function('func2', 'func2', [], [], f32, func2_statements)
|
||||||
|
|
||||||
|
# With params, no return value
|
||||||
|
# func3 left right = (func4 left right) + func2
|
||||||
|
func3_statements = [
|
||||||
|
wasm.Statement('local.get', '$left'),
|
||||||
|
wasm.Statement('local.get', '$right'),
|
||||||
|
wasm.Statement('call', '$func4'),
|
||||||
|
wasm.Statement('call', '$func2'),
|
||||||
|
wasm.Statement('f32.add'),
|
||||||
|
wasm.Statement('call', '$log'),
|
||||||
|
]
|
||||||
|
func3 = wasm.Function('func3', 'func3', [
|
||||||
|
('left', f32, ),
|
||||||
|
('right', f32, ),
|
||||||
|
], [], none, func3_statements)
|
||||||
|
|
||||||
|
# func4: With params, and return value
|
||||||
|
# func4 l r = l * 2 + r
|
||||||
|
func4_statements = [
|
||||||
|
wasm.Statement('local.get', '$left'),
|
||||||
|
wasm.Statement('f32.const', '2'),
|
||||||
|
wasm.Statement('f32.mul'),
|
||||||
|
wasm.Statement('local.get', '$right'),
|
||||||
|
wasm.Statement('f32.add'),
|
||||||
|
wasm.Statement('return'),
|
||||||
|
]
|
||||||
|
func4 = wasm.Function('func4', None, [
|
||||||
|
('left', f32, ),
|
||||||
|
('right', f32, ),
|
||||||
|
], [], f32, func4_statements)
|
||||||
|
|
||||||
|
module = wasm.Module()
|
||||||
|
module.imports.append(wasm.Import('imports', 'log', 'log', [('a', f32, )], none))
|
||||||
|
module.imports.append(wasm.Import('imports', 'remote_value_get', 'remote_value_get', [], f32))
|
||||||
|
module.functions.append(func1)
|
||||||
|
module.functions.append(func2)
|
||||||
|
module.functions.append(func3)
|
||||||
|
module.functions.append(func4)
|
||||||
|
|
||||||
|
def my_remote_value_get() -> float:
|
||||||
|
return 19.0
|
||||||
|
|
||||||
|
log = []
|
||||||
|
def my_log(a: float) -> None:
|
||||||
|
log.append(a)
|
||||||
|
|
||||||
|
imports = {'log': my_log, 'remote_value_get': my_remote_value_get}
|
||||||
|
|
||||||
|
run(module, 'func1', imports=imports)
|
||||||
|
|
||||||
|
assert [87, 0] == log
|
||||||
|
|
||||||
|
sut.transform(module)
|
||||||
|
|
||||||
|
run(module, 'func1', imports=imports)
|
||||||
|
|
||||||
|
assert 0
|
||||||
|
|
||||||
|
assert [87, 0] == log
|
||||||
|
|
||||||
|
# def test_locals():
|
||||||
Loading…
x
Reference in New Issue
Block a user