MVP #1

Merged
jbwdevries merged 73 commits from idea_crc32 into master 2022-08-21 12:59:21 +00:00
3 changed files with 222 additions and 159 deletions
Showing only changes of commit c16eb86e10 - Show all commits

View File

@ -12,7 +12,7 @@ from . import wasm
StatementGenerator = Generator[wasm.Statement, None, None] StatementGenerator = Generator[wasm.Statement, None, None]
WLocals = Dict[str, str] WLocals = Dict[str, wasm.OurType]
class Visitor: class Visitor:
""" """
@ -21,6 +21,29 @@ class Visitor:
Since we need to visit the whole tree, there's no point in subclassing Since we need to visit the whole tree, there's no point in subclassing
the buildin visitor the buildin visitor
""" """
def __init__(self) -> None:
self._type_map: WLocals = {
'None': wasm.OurTypeNone(),
'bool': wasm.OurTypeBool(),
'i32': wasm.OurTypeInt32(),
'i64': wasm.OurTypeInt64(),
'f32': wasm.OurTypeFloat32(),
'f64': wasm.OurTypeFloat64(),
}
def _get_type(self, name: Union[None, str, ast.expr]) -> wasm.OurType:
if name is None:
name = 'None'
if isinstance(name, ast.expr):
if isinstance(name, ast.Name):
name = name.id
else:
raise NotImplementedError(f'_get_type(ast.{type(name).__name__})')
return self._type_map[name]
def visit_Module(self, node: ast.Module) -> wasm.Module: def visit_Module(self, node: ast.Module) -> wasm.Module:
""" """
Visits a Python module, which results in a wasm Module Visits a Python module, which results in a wasm Module
@ -33,12 +56,12 @@ class Visitor:
'___new_reference___', '___new_reference___',
False, False,
[ [
('alloc_size', 'i32'), ('alloc_size', self._get_type('i32')),
], ],
[ [
('result', 'i32'), ('result', self._get_type('i32')),
], ],
'i32', self._get_type('i32'),
[ [
wasm.Statement('i32.const', '0'), wasm.Statement('i32.const', '0'),
wasm.Statement('i32.const', '0'), wasm.Statement('i32.const', '0'),
@ -69,7 +92,7 @@ class Visitor:
if isinstance(stmt, ast.ClassDef): if isinstance(stmt, ast.ClassDef):
wclass = self.pre_visit_ClassDef(module, stmt) wclass = self.pre_visit_ClassDef(module, stmt)
module.classes.append(wclass) self._type_map[wclass.name] = wclass
continue continue
# No other pre visits to do # No other pre visits to do
@ -122,13 +145,13 @@ class Visitor:
assert not call.keywords assert not call.keywords
mod, name, intname = _parse_import_decorator(node.name, call.args) mod, name, intname = self._parse_import_decorator(node.name, call.args)
return wasm.Import(mod, name, intname, _parse_import_args(node.args)) return wasm.Import(mod, name, intname, self._parse_import_args(node.args))
else: else:
raise NotImplementedError raise NotImplementedError
result = None if node.returns is None else _parse_annotation(node.returns) result = None if node.returns is None else self._get_type(node.returns)
assert not node.args.vararg assert not node.args.vararg
assert not node.args.kwonlyargs assert not node.args.kwonlyargs
@ -136,23 +159,12 @@ class Visitor:
assert not node.args.kwarg assert not node.args.kwarg
assert not node.args.defaults assert not node.args.defaults
class_lookup = { params: List[wasm.Param] = []
x.name: x
for x in module.classes
}
params = []
for arg in [*node.args.posonlyargs, *node.args.args]: for arg in [*node.args.posonlyargs, *node.args.args]:
if not isinstance(arg.annotation, ast.Name): params.append((arg.arg, self._get_type(arg.annotation), ))
raise NotImplementedError
if arg.annotation.id in class_lookup: locals_: List[wasm.Param] = [
params.append((arg.arg, 'i32', )) ('___new_reference___addr', self._get_type('i32')), # For the ___new_reference__ method
else:
params.append((arg.arg, _parse_annotation(arg.annotation), ))
locals_ = [
('___new_reference___addr', 'i32'), # For the ___new_reference__ method
] ]
return wasm.Function(node.name, exported, params, locals_, result, []) return wasm.Function(node.name, exported, params, locals_, result, [])
@ -180,7 +192,7 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
node: ast.ClassDef, node: ast.ClassDef,
) -> wasm.Class: ) -> wasm.OurTypeClass:
""" """
TODO: Document this TODO: Document this
""" """
@ -200,7 +212,7 @@ class Visitor:
raise NotImplementedError raise NotImplementedError
if not isinstance(stmt.annotation, ast.Name): if not isinstance(stmt.annotation, ast.Name):
raise NotImplementedError raise NotImplementedError('Cannot recurse classes yet')
if stmt.annotation.id != 'i32': if stmt.annotation.id != 'i32':
raise NotImplementedError raise NotImplementedError
@ -217,13 +229,13 @@ class Visitor:
default = wasm.Constant(stmt.value.value) default = wasm.Constant(stmt.value.value)
member = wasm.ClassMember( member = wasm.ClassMember(
stmt.target.id, stmt.annotation.id, offset, default stmt.target.id, self._get_type(stmt.annotation), offset, default
) )
members.append(member) members.append(member)
offset += member.alloc_size() offset += member.type.alloc_size()
return wasm.Class(node.name, members) return wasm.OurTypeClass(node.name, members)
def visit_stmt( def visit_stmt(
self, self,
@ -241,7 +253,7 @@ class Visitor:
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
assert isinstance(stmt.value, ast.Call) assert isinstance(stmt.value, ast.Call)
return self.visit_Call(module, wlocals, "None", stmt.value) return self.visit_Call(module, wlocals, self._get_type('None'), stmt.value)
if isinstance(stmt, ast.If): if isinstance(stmt, ast.If):
return self.visit_If(module, func, wlocals, stmt) return self.visit_If(module, func, wlocals, stmt)
@ -261,7 +273,7 @@ class Visitor:
assert stmt.value is not None assert stmt.value is not None
return_type = _parse_annotation(func.returns) return_type = self._get_type(func.returns)
yield from self.visit_expr(module, wlocals, return_type, stmt.value) yield from self.visit_expr(module, wlocals, return_type, stmt.value)
yield wasm.Statement('return') yield wasm.Statement('return')
@ -279,7 +291,7 @@ class Visitor:
yield from self.visit_expr( yield from self.visit_expr(
module, module,
wlocals, wlocals,
'bool', self._get_type('bool'),
stmt.test, stmt.test,
) )
@ -299,7 +311,7 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.expr, node: ast.expr,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
@ -332,7 +344,7 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.UnaryOp, node: ast.UnaryOp,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
@ -356,7 +368,7 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.BinOp, node: ast.BinOp,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
@ -366,11 +378,11 @@ class Visitor:
yield from self.visit_expr(module, wlocals, exp_type, node.right) yield from self.visit_expr(module, wlocals, exp_type, node.right)
if isinstance(node.op, ast.Add): if isinstance(node.op, ast.Add):
yield wasm.Statement('{}.add'.format(exp_type)) yield wasm.Statement('{}.add'.format(exp_type.to_wasm()))
return return
if isinstance(node.op, ast.Sub): if isinstance(node.op, ast.Sub):
yield wasm.Statement('{}.sub'.format(exp_type)) yield wasm.Statement('{}.sub'.format(exp_type.to_wasm()))
return return
raise NotImplementedError(node.op) raise NotImplementedError(node.op)
@ -379,19 +391,19 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.Compare, node: ast.Compare,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
Visits a Compare node as (part of) an expression Visits a Compare node as (part of) an expression
""" """
assert 'bool' == exp_type assert isinstance(exp_type, wasm.OurTypeBool)
if 1 != len(node.ops) or 1 != len(node.comparators): if 1 != len(node.ops) or 1 != len(node.comparators):
raise NotImplementedError raise NotImplementedError
yield from self.visit_expr(module, wlocals, 'i32', node.left) yield from self.visit_expr(module, wlocals, self._get_type('i32'), node.left)
yield from self.visit_expr(module, wlocals, 'i32', node.comparators[0]) yield from self.visit_expr(module, wlocals, self._get_type('i32'), node.comparators[0])
if isinstance(node.ops[0], ast.Lt): if isinstance(node.ops[0], ast.Lt):
yield wasm.Statement('i32.lt_s') yield wasm.Statement('i32.lt_s')
@ -411,7 +423,7 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.Call, node: ast.Call,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
@ -422,11 +434,15 @@ class Visitor:
called_name = node.func.id called_name = node.func.id
search_list: List[Union[wasm.Function, wasm.Import, wasm.Class]] if called_name in self._type_map:
klass = self._type_map[called_name]
if isinstance(klass, wasm.OurTypeClass):
return self.visit_Call_class(module, wlocals, exp_type, node, klass)
search_list: List[Union[wasm.Function, wasm.Import]]
search_list = [ search_list = [
*module.functions, *module.functions,
*module.imports, *module.imports,
*module.classes,
] ]
called_func_list = [ called_func_list = [
@ -438,18 +454,15 @@ class Visitor:
assert 1 == len(called_func_list), \ assert 1 == len(called_func_list), \
'Could not find function {}'.format(node.func.id) 'Could not find function {}'.format(node.func.id)
if isinstance(called_func_list[0], wasm.Class):
return self.visit_Call_class(module, wlocals, exp_type, node, called_func_list[0])
return self.visit_Call_func(module, wlocals, exp_type, node, called_func_list[0]) return self.visit_Call_func(module, wlocals, exp_type, node, called_func_list[0])
def visit_Call_class( def visit_Call_class(
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.Call, node: ast.Call,
cls: wasm.Class, cls: wasm.OurTypeClass,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
Visits a Call node as (part of) an expression Visits a Call node as (part of) an expression
@ -469,8 +482,8 @@ class Visitor:
raise NotImplementedError raise NotImplementedError
yield wasm.Statement('local.get', '$___new_reference___addr') yield wasm.Statement('local.get', '$___new_reference___addr')
yield wasm.Statement(f'{member.type}.const', str(arg.value)) yield wasm.Statement(f'{member.type.to_wasm()}.const', str(arg.value))
yield wasm.Statement(f'{member.type}.store', 'offset=' + str(member.offset)) yield wasm.Statement(f'{member.type.to_wasm()}.store', 'offset=' + str(member.offset))
yield wasm.Statement('local.get', '$___new_reference___addr') yield wasm.Statement('local.get', '$___new_reference___addr')
@ -478,7 +491,7 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.Call, node: ast.Call,
func: Union[wasm.Function, wasm.Import], func: Union[wasm.Function, wasm.Import],
) -> StatementGenerator: ) -> StatementGenerator:
@ -491,7 +504,7 @@ class Visitor:
called_params = func.params called_params = func.params
called_result = func.result called_result = func.result
assert exp_type == called_result assert exp_type == called_result, 'Function does not match expected type'
assert len(called_params) == len(node.args), \ assert len(called_params) == len(node.args), \
'{}:{} Function {} requires {} arguments, but {} are supplied'.format( '{}:{} Function {} requires {} arguments, but {} are supplied'.format(
@ -507,32 +520,32 @@ class Visitor:
'${}'.format(called_name), '${}'.format(called_name),
) )
def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator: def visit_Constant(self, exp_type: wasm.OurType, node: ast.Constant) -> StatementGenerator:
""" """
Visits a Constant node as (part of) an expression Visits a Constant node as (part of) an expression
""" """
if 'i32' == exp_type: if isinstance(exp_type, wasm.OurTypeInt32):
assert isinstance(node.value, int) assert isinstance(node.value, int)
assert -2147483648 <= node.value <= 2147483647 assert -2147483648 <= node.value <= 2147483647
yield wasm.Statement('i32.const', str(node.value)) yield wasm.Statement('i32.const', str(node.value))
return return
if 'i64' == exp_type: if isinstance(exp_type, wasm.OurTypeInt64):
assert isinstance(node.value, int) assert isinstance(node.value, int)
assert -9223372036854775808 <= node.value <= 9223372036854775807 assert -9223372036854775808 <= node.value <= 9223372036854775807
yield wasm.Statement('i64.const', str(node.value)) yield wasm.Statement('i64.const', str(node.value))
return return
if 'f32' == exp_type: if isinstance(exp_type, wasm.OurTypeFloat32):
assert isinstance(node.value, float) assert isinstance(node.value, float)
# TODO: Size check? # TODO: Size check?
yield wasm.Statement('f32.const', node.value.hex()) yield wasm.Statement('f32.const', node.value.hex())
return return
if 'f64' == exp_type: if isinstance(exp_type, wasm.OurTypeFloat64):
assert isinstance(node.value, float) assert isinstance(node.value, float)
# TODO: Size check? # TODO: Size check?
@ -545,27 +558,21 @@ class Visitor:
self, self,
module: wasm.Module, module: wasm.Module,
wlocals: WLocals, wlocals: WLocals,
exp_type: str, exp_type: wasm.OurType,
node: ast.Attribute, node: ast.Attribute,
) -> StatementGenerator: ) -> StatementGenerator:
""" """
Visits an Attribute node as (part of) an expression Visits an Attribute node as (part of) an expression
""" """
if not isinstance(node.value, ast.Name): if not isinstance(node.value, ast.Name):
raise NotImplementedError raise NotImplementedError
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
raise NotImplementedError raise NotImplementedError
cls_list = [ cls = wlocals[node.value.id]
x
for x in module.classes
if x.name == 'Rectangle' # TODO: STUB, since we can't acces the type properly
]
assert len(cls_list) == 1 assert isinstance(cls, wasm.OurTypeClass), f'Cannot take property of {cls}'
cls = cls_list[0]
member_list = [ member_list = [
x x
@ -577,31 +584,33 @@ class Visitor:
member = member_list[0] member = member_list[0]
yield wasm.Statement('local.get', '$' + node.value.id) yield wasm.Statement('local.get', '$' + node.value.id)
yield wasm.Statement(exp_type + '.load', 'offset=' + str(member.offset)) yield wasm.Statement(exp_type.to_wasm() + '.load', 'offset=' + str(member.offset))
def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator: def visit_Name(self, wlocals: WLocals, exp_type: wasm.OurType, node: ast.Name) -> StatementGenerator:
""" """
Visits a Name node as (part of) an expression Visits a Name node as (part of) an expression
""" """
assert node.id in wlocals assert node.id in wlocals
if exp_type == 'i64' and wlocals[node.id] == 'i32': if (isinstance(exp_type, wasm.OurTypeInt64)
and isinstance(wlocals[node.id], wasm.OurTypeInt32)):
yield wasm.Statement('local.get', '${}'.format(node.id)) yield wasm.Statement('local.get', '${}'.format(node.id))
yield wasm.Statement('i64.extend_i32_s') yield wasm.Statement('i64.extend_i32_s')
return return
if exp_type == 'f64' and wlocals[node.id] == 'f32': if (isinstance(exp_type, wasm.OurTypeFloat64)
and isinstance(wlocals[node.id], wasm.OurTypeFloat32)):
yield wasm.Statement('local.get', '${}'.format(node.id)) yield wasm.Statement('local.get', '${}'.format(node.id))
yield wasm.Statement('f64.promote_f32') yield wasm.Statement('f64.promote_f32')
return return
assert exp_type == wlocals[node.id] assert exp_type == wlocals[node.id], (exp_type, wlocals[node.id], )
yield wasm.Statement( yield wasm.Statement(
'local.get', 'local.get',
'${}'.format(node.id), '${}'.format(node.id),
) )
def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]: def _parse_import_decorator(self, func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]:
""" """
Parses an @import decorator Parses an @import decorator
""" """
@ -623,7 +632,7 @@ def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str,
return module, name, func_name return module, name, func_name
def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]: def _parse_import_args(self, args: ast.arguments) -> List[wasm.Param]:
""" """
Parses the arguments for an @imported method Parses the arguments for an @imported method
""" """
@ -640,16 +649,6 @@ def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]:
] ]
return [ return [
(arg.arg, _parse_annotation(arg.annotation)) (arg.arg, self._get_type(arg.annotation))
for arg in arg_list for arg in arg_list
] ]
def _parse_annotation(ann: Optional[ast.expr]) -> str:
"""
Parses a typing annotation
"""
assert ann is not None, 'Web Assembly requires type annotations'
assert isinstance(ann, ast.Name)
result = ann.id
assert result in ['i32', 'i64', 'f32', 'f64'], result
return result

View File

@ -2,9 +2,90 @@
Python classes for storing the representation of Web Assembly code Python classes for storing the representation of Web Assembly code
""" """
from typing import Iterable, List, Optional, Tuple, Union from typing import Any, Iterable, List, Optional, Tuple, Union
Param = Tuple[str, str] ###
### This part is more intermediate code
###
class OurType:
def to_wasm(self) -> str:
raise NotImplementedError
def alloc_size(self) -> int:
raise NotImplementedError
class OurTypeNone(OurType):
pass
class OurTypeBool(OurType):
pass
class OurTypeInt32(OurType):
def to_wasm(self) -> str:
return 'i32'
def alloc_size(self) -> int:
return 4
class OurTypeInt64(OurType):
def to_wasm(self) -> str:
return 'i64'
def alloc_size(self) -> int:
return 8
class OurTypeFloat32(OurType):
def to_wasm(self) -> str:
return 'f32'
def alloc_size(self) -> int:
return 4
class OurTypeFloat64(OurType):
def to_wasm(self) -> str:
return 'f64'
def alloc_size(self) -> int:
return 8
class Constant:
"""
TODO
"""
def __init__(self, value: Union[None, bool, int, float]) -> None:
self.value = value
class ClassMember:
"""
Represents a class member
"""
def __init__(self, name: str, type_: OurType, offset: int, default: Optional[Constant]) -> None:
self.name = name
self.type = type_
self.offset = offset
self.default = default
class OurTypeClass(OurType):
"""
Represents a class
"""
def __init__(self, name: str, members: List[ClassMember]) -> None:
self.name = name
self.members = members
def to_wasm(self) -> str:
return 'i32' # WASM uses 32 bit pointers
def alloc_size(self) -> int:
return sum(x.type.alloc_size() for x in self.members)
Param = Tuple[str, OurType]
###
## This part is more actual web assembly
###
class Import: class Import:
""" """
@ -62,7 +143,7 @@ class Function:
exported: bool, exported: bool,
params: Iterable[Param], params: Iterable[Param],
locals_: Iterable[Param], locals_: Iterable[Param],
result: Optional[str], result: Optional[OurType],
statements: Iterable[Statement], statements: Iterable[Statement],
) -> None: ) -> None:
self.name = name self.name = name
@ -79,55 +160,19 @@ class Function:
header = ('(export "{}")' if self.exported else '${}').format(self.name) header = ('(export "{}")' if self.exported else '${}').format(self.name)
for nam, typ in self.params: for nam, typ in self.params:
header += f' (param ${nam} {typ})' header += f' (param ${nam} {typ.to_wasm()})'
if self.result: if self.result:
header += f' (result {self.result})' header += f' (result {self.result.to_wasm()})'
for nam, typ in self.locals: for nam, typ in self.locals:
header += f' (local ${nam} {typ})' header += f' (local ${nam} {typ.to_wasm()})'
return '(func {}\n {}\n )'.format( return '(func {}\n {}\n )'.format(
header, header,
'\n '.join(x.generate() for x in self.statements), '\n '.join(x.generate() for x in self.statements),
) )
class Constant:
"""
TODO
"""
def __init__(self, value: Union[None, bool, int, float]) -> None:
self.value = value
class ClassMember:
"""
Represents a Web Assembly class member
"""
def __init__(self, name: str, type_: str, offset: int, default: Optional[Constant]) -> None:
self.name = name
self.type = type_
self.offset = offset
self.default = default
def alloc_size(self) -> int:
SIZE_MAP = {
'i32': 4,
'i64': 4,
}
return SIZE_MAP[self.type]
class Class:
"""
Represents a Web Assembly class
"""
def __init__(self, name: str, members: List[ClassMember]) -> None:
self.name = name
self.members = members
def alloc_size(self) -> int:
return sum(x.alloc_size() for x in self.members)
class Module: class Module:
""" """
Represents a Web Assembly module Represents a Web Assembly module
@ -135,7 +180,6 @@ class Module:
def __init__(self) -> None: def __init__(self) -> None:
self.imports: List[Import] = [] self.imports: List[Import] = []
self.functions: List[Function] = [] self.functions: List[Function] = []
self.classes: List[Class] = []
def generate(self) -> str: def generate(self) -> str:
""" """

View File

@ -251,6 +251,26 @@ def testEntry() -> i32:
assert 8947 == result.returned_value assert 8947 == result.returned_value
assert [] == result.log_int32_list assert [] == result.log_int32_list
@pytest.mark.integration_test
def test_struct_0():
code_py = """
class CheckedValue:
value: i32
@exported
def testEntry() -> i32:
return helper(CheckedValue(2345))
def helper(cv: CheckedValue) -> i32:
return cv.value
"""
result = Suite(code_py, 'test_call').run_code()
assert 2345 == result.returned_value
assert [] == result.log_int32_list
@pytest.mark.integration_test @pytest.mark.integration_test
def test_struct_1(): def test_struct_1():
code_py = """ code_py = """