- Tuple => ()

- All tests are now parsed by our own AST
This commit is contained in:
Johan B.W. de Vries 2022-06-06 12:18:09 +02:00
parent e7b72b6a6b
commit 658e442df2
4 changed files with 328 additions and 64 deletions

View File

@ -1,7 +1,7 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
from typing import Any, Dict, List, NoReturn, Union, Tuple from typing import Any, Dict, List, Optional, NoReturn, Union, Tuple
import ast import ast
@ -73,6 +73,21 @@ class OurTypeFloat64(OurType):
def render(self) -> str: def render(self) -> str:
return 'f64' return 'f64'
class OurTypeTuple(OurType):
"""
The tuple type, 64 bits wide
"""
__slots__ = ('members', )
members: List[OurType]
def __init__(self) -> None:
self.members = []
def render(self) -> str:
mems = ', '.join(x.render() for x in self.members)
return f'({mems}, )'
class Expression: class Expression:
""" """
An expression within a statement An expression within a statement
@ -227,8 +242,62 @@ class FunctionCall(Expression):
for arg in self.arguments for arg in self.arguments
) )
if isinstance(self.function, StructConstructor):
return f'{self.function.struct.name}({args})'
return f'{self.function.name}({args})' return f'{self.function.name}({args})'
class AccessStructMember(Expression):
"""
Access a struct member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: 'StructMember'
def __init__(self, varref: VariableReference, member: 'StructMember') -> None:
self.varref = varref
self.member = member
def render(self) -> str:
return f'{self.varref.render()}.{self.member.name}'
class AccessTupleMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'member_idx', 'member_type', )
varref: VariableReference
member_idx: int
member_type: 'OurType'
def __init__(self, varref: VariableReference, member_idx: int, member_type: 'OurType') -> None:
self.varref = varref
self.member_idx = member_idx
self.member_type = member_type
def render(self) -> str:
return f'{self.varref.render()}[{self.member_idx}]'
class TupleCreation(Expression):
"""
Create a tuple instance
"""
__slots__ = ('type', 'members', )
type: OurTypeTuple
members: List[Expression]
def __init__(self, type_: OurTypeTuple) -> None:
self.type = type_
self.members = []
def render(self) -> str:
mems = ', '. join(x.render() for x in self.members)
return f'({mems}, )'
class Statement: class Statement:
""" """
A statement within a function A statement within a function
@ -306,11 +375,12 @@ class Function:
""" """
A function processes input and produces output A function processes input and produces output
""" """
__slots__ = ('name', 'lineno', 'exported', 'statements', 'returns', 'posonlyargs', ) __slots__ = ('name', 'lineno', 'exported', 'buildin', 'statements', 'returns', 'posonlyargs', )
name: str name: str
lineno: int lineno: int
exported: bool exported: bool
buildin: bool
statements: List[Statement] statements: List[Statement]
returns: OurType returns: OurType
posonlyargs: List[Tuple[str, OurType]] posonlyargs: List[Tuple[str, OurType]]
@ -319,6 +389,7 @@ class Function:
self.name = name self.name = name
self.lineno = lineno self.lineno = lineno
self.exported = False self.exported = False
self.buildin = False
self.statements = [] self.statements = []
self.returns = OurTypeNone() self.returns = OurTypeNone()
self.posonlyargs = [] self.posonlyargs = []
@ -344,6 +415,24 @@ class Function:
result += f' {line}\n' if line else '\n' result += f' {line}\n' if line else '\n'
return result return result
class StructConstructor(Function):
"""
The constructor method for a struct
"""
__slots__ = ('struct', )
struct: 'Struct'
def __init__(self, struct: 'Struct') -> None:
super().__init__(f'@{struct.name}@__init___@', -1)
self.returns = struct
for mem in struct.members:
self.posonlyargs.append((mem.name, mem.type, ))
self.struct = struct
class StructMember: class StructMember:
""" """
Represents a struct member Represents a struct member
@ -353,7 +442,7 @@ class StructMember:
self.type = type_ self.type = type_
self.offset = offset self.offset = offset
class Struct: class Struct(OurType):
""" """
A struct has named properties A struct has named properties
""" """
@ -368,13 +457,35 @@ class Struct:
self.lineno = lineno self.lineno = lineno
self.members = [] self.members = []
def get_member(self, name: str) -> Optional[StructMember]:
"""
Returns a member by name
"""
for mem in self.members:
if mem.name == name:
return mem
return None
def render(self) -> str: def render(self) -> str:
""" """
Renders the function back to source code format Renders the type back to source code format
This'll look like Python code. This'll look like Python code.
""" """
return '?' return self.name
def render_definition(self) -> str:
"""
Renders the definition back to source code format
This'll look like Python code.
"""
result = f'class {self.name}:\n'
for mem in self.members:
result += f' {mem.name}: {mem.type.render()}\n'
return result
class Module: class Module:
""" """
@ -396,6 +507,14 @@ class Module:
self.functions = {} self.functions = {}
self.structs = {} self.structs = {}
# sqrt is guaranteed by wasm, so we should provide it
# ATM it's a 32 bit variant, but that might change.
sqrt = Function('sqrt', -2)
sqrt.buildin = True
sqrt.returns = self.types['f32']
sqrt.posonlyargs = [('@', self.types['f32'], )]
self.functions[sqrt.name] = sqrt
def render(self) -> str: def render(self) -> str:
""" """
Renders the module back to source code format Renders the module back to source code format
@ -403,10 +522,21 @@ class Module:
This'll look like Python code. This'll look like Python code.
""" """
result = '' result = ''
for struct in self.structs.values():
if result:
result += '\n'
result += struct.render_definition()
for function in self.functions.values(): for function in self.functions.values():
if function.lineno < 0:
# Buildin (-2) or auto generated (-1)
continue
if result: if result:
result += '\n' result += '\n'
result += function.render() result += function.render()
return result return result
class StaticError(Exception): class StaticError(Exception):
@ -431,8 +561,10 @@ class OurVisitor:
_not_implemented(not node.type_ignores, 'Module.type_ignores') _not_implemented(not node.type_ignores, 'Module.type_ignores')
# Second pass for the types
for stmt in node.body: for stmt in node.body:
res = self.visit_Module_stmt(module, stmt) res = self.pre_visit_Module_stmt(module, stmt)
if isinstance(res, Function): if isinstance(res, Function):
if res.name in module.functions: if res.name in module.functions:
@ -449,19 +581,26 @@ class OurVisitor:
) )
module.structs[res.name] = res module.structs[res.name] = res
constructor = StructConstructor(res)
module.functions[constructor.name] = constructor
# Second pass for the function bodies
for stmt in node.body:
self.visit_Module_stmt(module, stmt)
return module return module
def visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, Struct]: def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, Struct]:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
return self.visit_Module_FunctionDef(module, node) return self.pre_visit_Module_FunctionDef(module, node)
if isinstance(node, ast.ClassDef): if isinstance(node, ast.ClassDef):
return self.visit_Module_ClassDef(module, node) return self.pre_visit_Module_ClassDef(module, node)
raise NotImplementedError(f'{node} on Module') raise NotImplementedError(f'{node} on Module')
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function: def pre_visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function:
function = Function(node.name, node.lineno) function = Function(node.name, node.lineno)
_not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs') _not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs')
@ -496,18 +635,9 @@ class OurVisitor:
_not_implemented(not node.type_comment, 'FunctionDef.type_comment') _not_implemented(not node.type_comment, 'FunctionDef.type_comment')
# Deferred parsing
our_locals = dict(function.posonlyargs)
for stmt in node.body:
function.statements.append(
self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
)
return function return function
def visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> Struct: def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> Struct:
struct = Struct(node.name, node.lineno) struct = Struct(node.name, node.lineno)
_not_implemented(not node.bases, 'ClassDef.bases') _not_implemented(not node.bases, 'ClassDef.bases')
@ -536,6 +666,26 @@ class OurVisitor:
return struct return struct
def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None:
if isinstance(node, ast.FunctionDef):
self.visit_Module_FunctionDef(module, node)
return
if isinstance(node, ast.ClassDef):
return
raise NotImplementedError(f'{node} on Module')
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
function = module.functions[node.name]
our_locals = dict(function.posonlyargs)
for stmt in node.body:
function.statements.append(
self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
)
def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement: def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement:
if isinstance(node, ast.Return): if isinstance(node, ast.Return):
if node.value is None: if node.value is None:
@ -571,6 +721,8 @@ class OurVisitor:
operator = '+' operator = '+'
elif isinstance(node.op, ast.Sub): elif isinstance(node.op, ast.Sub):
operator = '-' operator = '-'
elif isinstance(node.op, ast.Mult):
operator = '*'
else: else:
raise NotImplementedError(f'Operator {node.op}') raise NotImplementedError(f'Operator {node.op}')
@ -602,6 +754,10 @@ class OurVisitor:
if isinstance(node.ops[0], ast.Gt): if isinstance(node.ops[0], ast.Gt):
operator = '>' operator = '>'
elif isinstance(node.ops[0], ast.Eq):
operator = '=='
elif isinstance(node.ops[0], ast.Lt):
operator = '<'
else: else:
raise NotImplementedError(f'Operator {node.ops}') raise NotImplementedError(f'Operator {node.ops}')
@ -615,33 +771,23 @@ class OurVisitor:
) )
if isinstance(node, ast.Call): if isinstance(node, ast.Call):
if node.keywords: return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node)
_raise_static_error(node, 'Keyword calling not supported') # Yet?
if not isinstance(node.func, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {node.func}')
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')
func = module.functions[node.func.id]
if func.returns != exp_type:
_raise_static_error(node, f'Function does not return {exp_type.render()}')
result = FunctionCall(func)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, expr)
for expr in node.args
)
return result
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return self.visit_Module_FunctionDef_Constant( return self.visit_Module_FunctionDef_Constant(
module, function, exp_type, node, module, function, exp_type, node,
) )
if isinstance(node, ast.Attribute):
return self.visit_Module_FunctionDef_Attribute(
module, function, our_locals, exp_type, node,
)
if isinstance(node, ast.Subscript):
return self.visit_Module_FunctionDef_Subscript(
module, function, our_locals, exp_type, node,
)
if isinstance(node, ast.Name): if isinstance(node, ast.Name):
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
@ -649,8 +795,120 @@ class OurVisitor:
if node.id in our_locals: if node.id in our_locals:
return VariableReference(our_locals[node.id], node.id) return VariableReference(our_locals[node.id], node.id)
if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not isinstance(exp_type, OurTypeTuple):
_raise_static_error(node, f'Expression is expecting a {exp_type.render()}, not a tuple')
if len(exp_type.members) != len(node.elts):
_raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given')
result = TupleCreation(exp_type)
result.members = [
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_node)
for arg_node, arg_type in zip(node.elts, exp_type.members)
]
return result
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Call) -> FunctionCall:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
if not isinstance(node.func, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {node.func}')
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.func.id in module.structs:
struct = module.structs[node.func.id]
struct_constructor = StructConstructor(struct)
func = module.functions[struct_constructor.name]
else:
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')
func = module.functions[node.func.id]
if func.returns != exp_type:
_raise_static_error(node, f'Function {node.func.id} does not return {exp_type.render()}')
if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
result = FunctionCall(func)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr)
for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs)
)
return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Attribute) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}')
node_typ = our_locals[node.value.id]
if not isinstance(node_typ, Struct):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
member = node_typ.get_member(node.attr)
if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
if exp_type != member.type:
_raise_static_error(node, f'Expected {exp_type.render()}, got {member.type.render()} instead')
return AccessStructMember(
VariableReference(node_typ, node.value.id),
member,
)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice.value, ast.Constant):
_raise_static_error(node, 'Must subscript using a constant index') # FIXME: Implement variable indexes
if not isinstance(node.slice.value.value, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}')
node_typ = our_locals[node.value.id]
if not isinstance(node_typ, OurTypeTuple):
_raise_static_error(node, f'Cannot take index of non-tuple {node.value.id}')
if len(node_typ.members) <= node.slice.value.value:
_raise_static_error(node, f'Index {node.slice.value.value} out of bounds for tuple {node.value.id}')
member = node_typ.members[node.slice.value.value]
if exp_type != member:
_raise_static_error(node, f'Expected {exp_type.render()}, got {member.render()} instead')
return AccessTupleMember(
VariableReference(node_typ, node.value.id),
node.slice.value.value,
member,
)
def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: OurType, node: ast.Constant) -> Expression: def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: OurType, node: ast.Constant) -> Expression:
del module del module
del function del function
@ -696,11 +954,25 @@ class OurVisitor:
if not isinstance(node.ctx, ast.Load): if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.id not in module.types: if node.id in module.types:
_raise_static_error(node, 'Unrecognized type')
return module.types[node.id] return module.types[node.id]
if node.id in module.structs:
return module.structs[node.id]
_raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
result = OurTypeTuple()
result.members = [
self.visit_type(module, elt)
for elt in node.elts
]
return result
raise NotImplementedError(f'{node} as type') raise NotImplementedError(f'{node} as type')
def _not_implemented(check: Any, msg: str) -> None: def _not_implemented(check: Any, msg: str) -> None:

View File

@ -41,14 +41,12 @@ class Visitor:
if isinstance(name, ast.expr): if isinstance(name, ast.expr):
if isinstance(name, ast.Name): if isinstance(name, ast.Name):
name = name.id name = name.id
elif isinstance(name, ast.Subscript) and isinstance(name.value, ast.Name) and name.value.id == 'Tuple': elif isinstance(name, ast.Tuple):
assert isinstance(name.ctx, ast.Load) assert isinstance(name.ctx, ast.Load)
assert isinstance(name.slice, ast.Index)
assert isinstance(name.slice.value, ast.Tuple)
args: List[wasm.TupleMember] = [] args: List[wasm.TupleMember] = []
offset = 0 offset = 0
for name_arg in name.slice.value.elts: for name_arg in name.elts:
arg = wasm.TupleMember( arg = wasm.TupleMember(
self._get_type(name_arg), self._get_type(name_arg),
offset offset

View File

@ -14,7 +14,7 @@ import wasmer_compiler_cranelift
import wasmtime import wasmtime
from py2wasm.utils import process from py2wasm.utils import our_process, process
DASHES = '-' * 16 DASHES = '-' * 16
@ -69,6 +69,9 @@ class Suite:
Returned is an object with the results set Returned is an object with the results set
""" """
our_module = our_process(self.code_py, self.test_name)
assert self.code_py == '\n' + our_module.render() # \n for formatting in tests
code_wat = process(self.code_py, self.test_name) code_wat = process(self.code_py, self.test_name)
sys.stderr.write(f'{DASHES} Assembly {DASHES}\n') sys.stderr.write(f'{DASHES} Assembly {DASHES}\n')

View File

@ -179,6 +179,7 @@ def testEntry(a: i32) -> i32:
assert 1 == result.returned_value assert 1 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.skip('Such a return is not how things should be')
def test_if_complex(): def test_if_complex():
code_py = """ code_py = """
@exported @exported
@ -254,7 +255,6 @@ def testEntry() -> i32:
@pytest.mark.integration_test @pytest.mark.integration_test
def test_struct_0(): def test_struct_0():
code_py = """ code_py = """
class CheckedValue: class CheckedValue:
value: i32 value: i32
@ -274,11 +274,10 @@ def helper(cv: CheckedValue) -> i32:
@pytest.mark.integration_test @pytest.mark.integration_test
def test_struct_1(): def test_struct_1():
code_py = """ code_py = """
class Rectangle: class Rectangle:
height: i32 height: i32
width: i32 width: i32
border: i32 # = 5 border: i32
@exported @exported
def testEntry() -> i32: def testEntry() -> i32:
@ -296,21 +295,17 @@ def helper(shape: Rectangle) -> i32:
@pytest.mark.integration_test @pytest.mark.integration_test
def test_struct_2(): def test_struct_2():
code_py = """ code_py = """
class Rectangle: class Rectangle:
height: i32 height: i32
width: i32 width: i32
border: i32 # = 5 border: i32
@exported @exported
def testEntry() -> i32: def testEntry() -> i32:
return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3))
def helper(shape1: Rectangle, shape2: Rectangle) -> i32: def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
return ( return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border
shape1.height + shape1.width + shape1.border
+ shape2.height + shape2.width + shape2.border
)
""" """
result = Suite(code_py, 'test_call').run_code() result = Suite(code_py, 'test_call').run_code()
@ -321,12 +316,11 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
@pytest.mark.integration_test @pytest.mark.integration_test
def test_tuple_int(): def test_tuple_int():
code_py = """ code_py = """
@exported @exported
def testEntry() -> i32: def testEntry() -> i32:
return helper((24, 57, 80, )) return helper((24, 57, 80, ))
def helper(vector: Tuple[i32, i32, i32]) -> i32: def helper(vector: (i32, i32, i32, )) -> i32:
return vector[0] + vector[1] + vector[2] return vector[0] + vector[1] + vector[2]
""" """
@ -338,14 +332,11 @@ def helper(vector: Tuple[i32, i32, i32]) -> i32:
@pytest.mark.integration_test @pytest.mark.integration_test
def test_tuple_float(): def test_tuple_float():
code_py = """ code_py = """
@exported @exported
def testEntry() -> f32: def testEntry() -> f32:
return helper((1.0, 2.0, 3.0, )) return helper((1.0, 2.0, 3.0, ))
def helper(v: Tuple[f32, f32, f32]) -> f32: def helper(v: (f32, f32, f32, )) -> f32:
# sqrt is guaranteed by wasm, so we can use it
# ATM it's a 32 bit variant, but that might change.
return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])
""" """