phasm/py2wasm/compiler.py
Johan B.W. de Vries 14eede6b06 Cleanup to wasm.py
2022-07-09 12:30:28 +02:00

390 lines
12 KiB
Python

"""
This module contains the code to convert parsed Ourlang into WebAssembly code
"""
from typing import Generator, Tuple
from . import ourlang
from . import wasm
# The statement-emitting compile steps in this module are generators that
# lazily produce WebAssembly statements; this is their return type.
Statements = Generator[wasm.Statement, None, None]
def type_(inp: ourlang.OurType) -> wasm.WasmType:
    """
    Map an Ourlang type onto the WebAssembly value type used to hold it.

    Raises NotImplementedError for Ourlang types without a mapping.
    """
    # Checked in order; the first isinstance() match wins.
    type_table = (
        (ourlang.OurTypeNone, wasm.WasmTypeNone),
        # WebAssembly only has 32 and 64 bit integer value types, so a
        # UInt8 is held in a full i32.
        (ourlang.OurTypeUInt8, wasm.WasmTypeInt32),
        (ourlang.OurTypeInt32, wasm.WasmTypeInt32),
        (ourlang.OurTypeInt64, wasm.WasmTypeInt64),
        (ourlang.OurTypeFloat32, wasm.WasmTypeFloat32),
        (ourlang.OurTypeFloat64, wasm.WasmTypeFloat64),
        # Structs, tuples and bytes are passed around as pointers into
        # linear memory, and pointers are i32.
        (
            (ourlang.Struct, ourlang.OurTypeTuple, ourlang.OurTypeBytes),
            wasm.WasmTypeInt32,
        ),
    )
    for our_types, make_wasm_type in type_table:
        if isinstance(inp, our_types):
            return make_wasm_type()
    raise NotImplementedError(type_, inp)
# Binary operators whose wasm instruction name is the same for i32, i64, f32
# and f64; expression() prepends the value type (e.g. 'i32.' + 'add').
OPERATOR_MAP = {
    '+': 'add',
    '-': 'sub',
    '*': 'mul',
    '==': 'eq',
}

# Ordering comparisons for i32. wasm integer comparisons come in signed (_s)
# and unsigned (_u) flavours; only the signed ones are used until unsigned
# Ourlang types exist.
I32_OPERATOR_MAP = {  # TODO: Introduce UInt32 type
    '<': 'lt_s',
    '>': 'gt_s',
    '<=': 'le_s',
    '>=': 'ge_s',
}

# Ordering comparisons for i64; signed only, like the i32 table.
I64_OPERATOR_MAP = {  # TODO: Introduce UInt64 type
    '<': 'lt_s',
    '>': 'gt_s',
    '<=': 'le_s',
    '>=': 'ge_s',
}
def expression(inp: ourlang.Expression) -> Statements:
    """
    Compile one Ourlang expression into WebAssembly statements.

    The produced statements push the value of ``inp`` onto the wasm value
    stack. Sub-expressions are always emitted before the instruction that
    consumes them.

    Raises NotImplementedError for expression kinds, types or operators that
    have no lowering yet.
    """
    # Literals become a single `<valtype>.const`. UInt8 has no wasm
    # counterpart, so its constants are emitted as i32 (matching type_).
    constant_opcodes = (
        (ourlang.ConstantUInt8, 'i32.const'),
        (ourlang.ConstantInt32, 'i32.const'),
        (ourlang.ConstantInt64, 'i64.const'),
        (ourlang.ConstantFloat32, 'f32.const'),
        (ourlang.ConstantFloat64, 'f64.const'),
    )
    for constant_cls, const_opcode in constant_opcodes:
        if isinstance(inp, constant_cls):
            yield wasm.Statement(const_opcode, str(inp.value))
            return

    if isinstance(inp, ourlang.VariableReference):
        yield wasm.Statement('local.get', f'${inp.name}')
        return

    if isinstance(inp, ourlang.BinaryOp):
        yield from expression(inp.left)
        yield from expression(inp.right)
        # The instruction is chosen from the expression's type (inp.type).
        # The shared table is consulted before the type-specific comparison
        # table; floats have no type-specific operators yet.
        binary_lowerings = (
            (ourlang.OurTypeInt32, 'i32', I32_OPERATOR_MAP),
            (ourlang.OurTypeInt64, 'i64', I64_OPERATOR_MAP),
            (ourlang.OurTypeFloat32, 'f32', {}),
            (ourlang.OurTypeFloat64, 'f64', {}),
        )
        for value_cls, prefix, specific_ops in binary_lowerings:
            if not isinstance(inp.type, value_cls):
                continue
            for op_table in (OPERATOR_MAP, specific_ops):
                suffix = op_table.get(inp.operator, None)
                if suffix:
                    yield wasm.Statement(f'{prefix}.{suffix}')
                    return
        raise NotImplementedError(expression, inp.type, inp.operator)

    if isinstance(inp, ourlang.UnaryOp):
        yield from expression(inp.right)
        # Float unary builtins map directly onto same-named wasm instructions.
        for value_cls, prefix in (
                (ourlang.OurTypeFloat32, 'f32'),
                (ourlang.OurTypeFloat64, 'f64'),
        ):
            if (isinstance(inp.type, value_cls)
                    and inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS):
                yield wasm.Statement(f'{prefix}.{inp.operator}')
                return
        # len() on bytes: a bytes value starts with its i32 length field, so
        # loading from the bytes pointer yields the length.
        if (isinstance(inp.type, ourlang.OurTypeInt32)
                and inp.operator == 'len'
                and isinstance(inp.right.type, ourlang.OurTypeBytes)):
            yield wasm.Statement('i32.load')
            return
        raise NotImplementedError(expression, inp.type, inp.operator)

    if isinstance(inp, ourlang.FunctionCall):
        # Arguments are pushed left to right, then the callee is invoked.
        for argument in inp.arguments:
            yield from expression(argument)
        yield wasm.Statement('call', f'${inp.function.name}')
        return

    if isinstance(inp, ourlang.AccessBytesIndex):
        if not isinstance(inp.type, ourlang.OurTypeUInt8):
            raise NotImplementedError(inp, inp.type)
        # Indexing goes through the bounds-checked ___access_bytes_index___
        # helper with (bytes pointer, index) on the stack.
        yield from expression(inp.varref)
        yield from expression(inp.index)
        yield wasm.Statement('call', '$___access_bytes_index___')
        return

    # FIXME: Properly implement this. render() stands in for the wasm value
    # type name below; that is a hack that doesn't really work consistently,
    # hence the explicit whitelist.
    renderable_types = (
        ourlang.OurTypeInt32, ourlang.OurTypeFloat32,
        ourlang.OurTypeInt64, ourlang.OurTypeFloat64,
    )

    if isinstance(inp, ourlang.AccessStructMember):
        member_type = inp.member.type
        if isinstance(member_type, ourlang.OurTypeUInt8):
            # UInt8 members are stored and loaded as a full i32
            load_type = 'i32'
        elif isinstance(member_type, renderable_types):
            load_type = member_type.render()
        else:
            raise NotImplementedError
        yield from expression(inp.varref)
        yield wasm.Statement(f'{load_type}.load', f'offset={inp.member.offset}')
        return

    if isinstance(inp, ourlang.AccessTupleMember):
        if not isinstance(inp.type, renderable_types):
            raise NotImplementedError(inp, inp.type)
        yield from expression(inp.varref)
        yield wasm.Statement(inp.type.render() + '.load', f'offset={inp.member.offset}')
        return

    raise NotImplementedError(expression, inp)
def statement_return(inp: ourlang.StatementReturn) -> Statements:
    """Compile ``return <value>``: push the value, then leave the function."""
    value_code = expression(inp.value)
    yield from value_code
    yield wasm.Statement('return')
def statement_if(inp: ourlang.StatementIf) -> Statements:
    """
    Compile an if statement into a wasm ``if ... [else ...] end`` block.

    The test expression is emitted first so its value is on the stack when
    ``if`` runs. The ``else`` arm is emitted only when the statement has
    else statements.
    """
    yield from expression(inp.test)
    yield wasm.Statement('if')
    yield from (
        wasm_stat
        for then_stat in inp.statements
        for wasm_stat in statement(then_stat)
    )
    if inp.else_statements:
        yield wasm.Statement('else')
        yield from (
            wasm_stat
            for else_stat in inp.else_statements
            for wasm_stat in statement(else_stat)
        )
    yield wasm.Statement('end')
def statement(inp: ourlang.Statement) -> Statements:
    """
    Compile one Ourlang statement into WebAssembly statements.

    Raises NotImplementedError for unsupported statement kinds.
    """
    for stat_cls, compile_stat in (
            (ourlang.StatementReturn, statement_return),
            (ourlang.StatementIf, statement_if),
    ):
        if isinstance(inp, stat_cls):
            yield from compile_stat(inp)
            return
    if isinstance(inp, ourlang.StatementPass):
        # `pass` produces no code at all
        return
    raise NotImplementedError(statement, inp)
def function_argument(inp: Tuple[str, ourlang.OurType]) -> wasm.Param:
    """Convert a ``(name, Ourlang type)`` pair into a ``(name, wasm type)`` parameter."""
    param_name = inp[0]
    param_type = type_(inp[1])
    return (param_name, param_type, )
def import_(inp: ourlang.Function) -> wasm.Import:
    """
    Describe an imported Ourlang function as a wasm import.

    The import comes from the ``imports`` namespace. Both name arguments of
    wasm.Import are set to the function's own name.
    """
    assert inp.imported
    param_list = [function_argument(arg) for arg in inp.posonlyargs]
    return_type = type_(inp.returns)
    return wasm.Import('imports', inp.name, inp.name, param_list, return_type)
def function(inp: ourlang.Function) -> wasm.Function:
    """
    Compile a non-imported Ourlang function into a wasm function.

    Tuple and struct constructors get a generated body. Every other function
    is compiled statement by statement. The function is exported under its
    own name when ``inp.exported`` is set.
    """
    assert not inp.imported

    # Constructors allocate their object via ___new_reference___ and keep the
    # returned address in a scratch local.
    constructor_generators = (
        (ourlang.TupleConstructor, _generate_tuple_constructor),
        (ourlang.StructConstructor, _generate_struct_constructor),
    )
    for constructor_cls, generate in constructor_generators:
        if isinstance(inp, constructor_cls):
            body = list(generate(inp))
            local_vars = [
                ('___new_reference___addr', wasm.WasmTypeInt32(), ),
            ]
            break
    else:
        body = [
            wasm_stat
            for our_stat in inp.statements
            for wasm_stat in statement(our_stat)
        ]
        local_vars = []  # TODO: collect the locals the body needs

    params = [function_argument(arg) for arg in inp.posonlyargs]
    export_name = inp.name if inp.exported else None
    return wasm.Function(
        inp.name,
        export_name,
        params,
        local_vars,
        type_(inp.returns),
        body,
    )
def module(inp: ourlang.Module) -> wasm.Module:
    """
    Compile a whole Ourlang module into a wasm module.

    Imported functions become wasm imports. The generated runtime helpers
    (___new_reference___ and ___access_bytes_index___) come first in the
    function list, followed by every non-imported function in the order of
    ``inp.functions``.
    """
    wasm_module = wasm.Module()

    wasm_module.imports = [
        import_(func)
        for func in inp.functions.values()
        if func.imported
    ]

    runtime_helpers = [
        _generate____new_reference___(inp),
        _generate____access_bytes_index___(inp),
    ]
    compiled = [
        function(func)
        for func in inp.functions.values()
        if not func.imported
    ]
    wasm_module.functions = runtime_helpers + compiled
    return wasm_module
def _generate____new_reference___(mod: ourlang.Module) -> wasm.Function:
    """
    Build the exported ``___new_reference___`` allocator function.

    The generated function takes the number of bytes to allocate
    (``$alloc_size``) and returns the address of the new block. It is a bump
    allocator: the i32 stored at address 0 holds the address handed out by
    the next call, and every call advances it by ``$alloc_size``. The
    generated code never frees memory, does not align, and does not check
    against (or grow) the memory size.

    NOTE(review): this relies on the i32 at address 0 being initialised
    elsewhere to a start address past itself. If it starts at 0, the first
    allocation would return address 0 and the caller would overwrite the
    heap pointer. Confirm where it is initialised.
    """
    return wasm.Function(
        '___new_reference___',
        '___new_reference___',
        [
            ('alloc_size', type_(mod.types['i32']), ),
        ],
        [
            ('result', type_(mod.types['i32']), ),
        ],
        type_(mod.types['i32']),
        [
            # Address operand for the final i32.store (the heap pointer slot)
            wasm.Statement('i32.const', '0'),
            # Read the current heap pointer; it is this call's result
            wasm.Statement('i32.const', '0'),
            wasm.Statement('i32.load'),
            wasm.Statement('local.tee', '$result', comment='Address for this call'),
            # Advance the heap pointer past the newly allocated block
            wasm.Statement('local.get', '$alloc_size'),
            wasm.Statement('i32.add'),
            wasm.Statement('i32.store', comment='Address for the next call'),
            wasm.Statement('local.get', '$result'),
        ],
    )
def _generate____access_bytes_index___(mod: ourlang.Module) -> wasm.Function:
    """
    Build the internal (not exported) ``___access_bytes_index___`` helper.

    The generated function takes a bytes pointer ``$byt`` and an index
    ``$ofs`` and returns the byte at that index, zero-extended to i32. A
    bytes value is laid out as an i32 length field followed by the raw
    bytes. Out-of-range indices return 0 instead of trapping.
    """
    return wasm.Function(
        '___access_bytes_index___',
        None,
        [
            ('byt', type_(mod.types['i32']), ),
            ('ofs', type_(mod.types['i32']), ),
        ],
        [
        ],
        type_(mod.types['i32']),
        [
            # Unsigned compare, so a negative $ofs also counts as out of bounds
            wasm.Statement('local.get', '$ofs'),
            wasm.Statement('local.get', '$byt'),
            wasm.Statement('i32.load'),
            wasm.Statement('i32.lt_u'),
            wasm.Statement('if', comment='$ofs < len($byt)'),
            # Address of the byte: $byt + 4 (skip length field) + $ofs
            wasm.Statement('local.get', '$byt'),
            wasm.Statement('i32.const', '4', comment='Leading size field'),
            wasm.Statement('i32.add'),
            wasm.Statement('local.get', '$ofs'),
            wasm.Statement('i32.add'),
            wasm.Statement('i32.load8_u', comment='Within bounds'),
            wasm.Statement('return'),
            wasm.Statement('end'),
            wasm.Statement('i32.const', str(0), comment='Out of bounds'),
            wasm.Statement('return'),
        ],
    )
def _generate_tuple_constructor(inp: ourlang.TupleConstructor) -> Statements:
    """
    Generate the body of a tuple constructor function.

    Allocates ``alloc_size()`` bytes through ___new_reference___, stores each
    argument ``$arg<idx>`` at its member's offset, and finally pushes the
    new tuple's address.

    Raises NotImplementedError for member types other than Int32, Float32,
    Int64 and Float64.
    """
    address_local = '$___new_reference___addr'
    # FIXME: Properly implement this. render() stands in for the wasm value
    # type name; that is a hack that doesn't really work consistently.
    storable_types = (
        ourlang.OurTypeInt32, ourlang.OurTypeFloat32,
        ourlang.OurTypeInt64, ourlang.OurTypeFloat64,
    )

    yield wasm.Statement('i32.const', str(inp.tuple.alloc_size()))
    yield wasm.Statement('call', '$___new_reference___')
    yield wasm.Statement('local.set', address_local)

    for member in inp.tuple.members:
        if not isinstance(member.type, storable_types):
            raise NotImplementedError
        # <type>.store pops (address, value) and writes at address + offset
        yield wasm.Statement('local.get', address_local)
        yield wasm.Statement('local.get', f'$arg{member.idx}')
        yield wasm.Statement(
            f'{member.type.render()}.store',
            f'offset={member.offset}',
        )

    yield wasm.Statement('local.get', address_local)
def _generate_struct_constructor(inp: ourlang.StructConstructor) -> Statements:
    """
    Generate the body of a struct constructor function.

    Allocates ``alloc_size()`` bytes through ___new_reference___, stores each
    parameter ``$<member name>`` at its member's offset, and finally pushes
    the new struct's address.

    Raises NotImplementedError for member types other than UInt8, Int32,
    Float32, Int64 and Float64.
    """
    address_local = '$___new_reference___addr'
    # FIXME: Properly implement this. render() stands in for the wasm value
    # type name; that is a hack that doesn't really work consistently.
    storable_types = (
        ourlang.OurTypeInt32, ourlang.OurTypeFloat32,
        ourlang.OurTypeInt64, ourlang.OurTypeFloat64,
    )

    yield wasm.Statement('i32.const', str(inp.struct.alloc_size()))
    yield wasm.Statement('call', '$___new_reference___')
    yield wasm.Statement('local.set', address_local)

    for member in inp.struct.members:
        if isinstance(member.type, ourlang.OurTypeUInt8):
            # UInt8 members are stored as a full i32
            store_type = 'i32'
        elif isinstance(member.type, storable_types):
            store_type = member.type.render()
        else:
            raise NotImplementedError
        # <type>.store pops (address, value) and writes at address + offset
        yield wasm.Statement('local.get', address_local)
        yield wasm.Statement('local.get', f'${member.name}')
        yield wasm.Statement(f'{store_type}.store', f'offset={member.offset}')

    yield wasm.Statement('local.get', address_local)