# phasm/phasm/compiler.py
# 2022-08-06 12:43:46 +02:00
#
# 591 lines
# 18 KiB
# Python

"""
This module contains the code to convert parsed Ourlang into WebAssembly code
"""
from typing import Generator
from . import ourlang
from . import typing
from . import wasm
from . import wasmeasy
from .wasmeasy import i32, i64
# Generator of WebAssembly statements; used as the return type of the
# expression / statement compilers and the constructor generators below.
Statements = Generator[wasm.Statement, None, None]

# Keyed by the Phasm type class, valued by the WebAssembly type prefix that is
# put in front of `.load` / `.store` when accessing a member of that type.
LOAD_STORE_TYPE_MAP = {
    typing.TypeUInt8: 'i32',
    typing.TypeUInt32: 'i32',
    typing.TypeUInt64: 'i64',
    typing.TypeInt32: 'i32',
    typing.TypeInt64: 'i64',
    typing.TypeFloat32: 'f32',
    typing.TypeFloat64: 'f64',
}
"""
When generating code, we sometimes need to load or store simple values
"""
def phasm_compile(inp: ourlang.Module) -> wasm.Module:
    """
    Public entry point of the compiler

    Takes a parsed Phasm module and hands it to `module`,
    returning the resulting WebAssembly module.
    """
    compiled = module(inp)
    return compiled
def type_(inp: typing.TypeBase) -> wasm.WasmType:
    """
    Compile: type

    Returns a new WebAssembly type object for the given Phasm type.
    Raises NotImplementedError when the type has no mapping.
    """
    # Checked top to bottom with isinstance; the first match wins
    conversions = (
        (typing.TypeNone, wasm.WasmTypeNone),
        # WebAssembly has only support for 32 and 64 bits
        # So we need to store more memory per byte
        (typing.TypeUInt8, wasm.WasmTypeInt32),
        (typing.TypeUInt32, wasm.WasmTypeInt32),
        (typing.TypeUInt64, wasm.WasmTypeInt64),
        (typing.TypeInt32, wasm.WasmTypeInt32),
        (typing.TypeInt64, wasm.WasmTypeInt64),
        (typing.TypeFloat32, wasm.WasmTypeFloat32),
        (typing.TypeFloat64, wasm.WasmTypeFloat64),
        # Structs, tuples and bytes are passed as pointer
        # And pointers are i32
        (
            (typing.TypeStruct, typing.TypeTuple, typing.TypeBytes),
            wasm.WasmTypeInt32,
        ),
    )

    for phasm_type, wasm_type in conversions:
        if isinstance(inp, phasm_type):
            return wasm_type()

    raise NotImplementedError(type_, inp)
# Operators that work for i32, i64, f32, f64
OPERATOR_MAP = {'+': 'add', '-': 'sub', '*': 'mul', '==': 'eq'}

# Comparison operators for the unsigned integer types (`_u` instructions)
U32_OPERATOR_MAP = {'<': 'lt_u', '>': 'gt_u', '<=': 'le_u', '>=': 'ge_u'}
U64_OPERATOR_MAP = {'<': 'lt_u', '>': 'gt_u', '<=': 'le_u', '>=': 'ge_u'}

# Comparison operators for the signed integer types (`_s` instructions)
I32_OPERATOR_MAP = {'<': 'lt_s', '>': 'gt_s', '<=': 'le_s', '>=': 'ge_s'}
I64_OPERATOR_MAP = {'<': 'lt_s', '>': 'gt_s', '<=': 'le_s', '>=': 'ge_s'}
def expression(inp: ourlang.Expression) -> Statements:
    """
    Compile: Any expression

    Yields the WebAssembly statements that evaluate the given expression.
    Operands and arguments are compiled recursively before the instruction
    that uses them.

    Raises NotImplementedError for expression kinds, or type and operator
    combinations, that are not handled here.
    """
    # Constants; u8 and u32 both end up as an i32 constant
    if isinstance(inp, (ourlang.ConstantUInt8, ourlang.ConstantUInt32)):
        yield i32.const(inp.value)
        return

    if isinstance(inp, ourlang.ConstantUInt64):
        yield i64.const(inp.value)
        return

    if isinstance(inp, ourlang.ConstantInt32):
        yield i32.const(inp.value)
        return

    if isinstance(inp, ourlang.ConstantInt64):
        yield i64.const(inp.value)
        return

    if isinstance(inp, ourlang.ConstantFloat32):
        yield wasm.Statement('f32.const', str(inp.value))
        return

    if isinstance(inp, ourlang.ConstantFloat64):
        yield wasm.Statement('f64.const', str(inp.value))
        return

    if isinstance(inp, ourlang.VariableReference):
        yield wasm.Statement('local.get', f'${inp.name}')
        return

    if isinstance(inp, ourlang.BinaryOp):
        # Both operands first, then the instruction that consumes them
        yield from expression(inp.left)
        yield from expression(inp.right)

        # (result type, instruction prefix, type specific comparison operators)
        # The floats only support the operators from OPERATOR_MAP
        typed_operators = (
            (typing.TypeUInt32, 'i32', U32_OPERATOR_MAP),
            (typing.TypeUInt64, 'i64', U64_OPERATOR_MAP),
            (typing.TypeInt32, 'i32', I32_OPERATOR_MAP),
            (typing.TypeInt64, 'i64', I64_OPERATOR_MAP),
            (typing.TypeFloat32, 'f32', {}),
            (typing.TypeFloat64, 'f64', {}),
        )
        for phasm_type, prefix, comparison_map in typed_operators:
            if not isinstance(inp.type, phasm_type):
                continue

            # The shared operators take precedence over the typed ones
            mnemonic = (
                OPERATOR_MAP.get(inp.operator, None)
                or comparison_map.get(inp.operator, None)
            )
            if mnemonic:
                yield wasm.Statement(f'{prefix}.{mnemonic}')
                return

        raise NotImplementedError(expression, inp.type, inp.operator)

    if isinstance(inp, ourlang.UnaryOp):
        yield from expression(inp.right)

        # Float operations that exist as WebAssembly instructions
        float_prefixes = (
            (typing.TypeFloat32, 'f32'),
            (typing.TypeFloat64, 'f64'),
        )
        for phasm_type, prefix in float_prefixes:
            if (
                isinstance(inp.type, phasm_type)
                and inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS
            ):
                yield wasm.Statement(f'{prefix}.{inp.operator}')
                return

        # len() of bytes: an i32 load from the address of the bytes value
        if (
            isinstance(inp.type, typing.TypeInt32)
            and inp.operator == 'len'
            and isinstance(inp.right.type, typing.TypeBytes)
        ):
            yield i32.load()
            return

        raise NotImplementedError(expression, inp.type, inp.operator)

    if isinstance(inp, ourlang.FunctionCall):
        for argument in inp.arguments:
            yield from expression(argument)

        yield wasm.Statement('call', f'${inp.function.name}')
        return

    if isinstance(inp, ourlang.AccessBytesIndex):
        if not isinstance(inp.type, typing.TypeUInt8):
            raise NotImplementedError(inp, inp.type)

        yield from expression(inp.varref)
        yield from expression(inp.index)
        yield wasm.Statement('call', '$___access_bytes_index___')
        return

    # Struct and tuple members are read the same way:
    # a typed load at the offset of the member
    if isinstance(inp, (ourlang.AccessStructMember, ourlang.AccessTupleMember)):
        mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__)
        if mtyp is None:
            # In the future might extend this by having structs or tuples
            # as members of struct or tuples
            raise NotImplementedError(expression, inp, inp.member)

        yield from expression(inp.varref)
        yield wasm.Statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
        return

    raise NotImplementedError(expression, inp)
def statement_return(inp: ourlang.StatementReturn) -> Statements:
    """
    Compile: Return statement

    Yields the statements for the returned value,
    followed by a `return` instruction.
    """
    for value_statement in expression(inp.value):
        yield value_statement

    yield wasm.Statement('return')
def statement_if(inp: ourlang.StatementIf) -> Statements:
    """
    Compile: If statement

    Yields the test expression, then an `if` ... `end` block containing
    the body. An `else` section is only emitted when there are
    else statements.
    """
    yield from expression(inp.test)

    yield wasm.Statement('if')
    for body_stat in inp.statements:
        yield from statement(body_stat)

    if inp.else_statements:
        yield wasm.Statement('else')
        for else_stat in inp.else_statements:
            yield from statement(else_stat)

    yield wasm.Statement('end')
def statement(inp: ourlang.Statement) -> Statements:
    """
    Compile: any statement

    Dispatches on the statement class. A pass statement yields nothing;
    an unknown statement class raises NotImplementedError.
    """
    if isinstance(inp, ourlang.StatementReturn):
        yield from statement_return(inp)
    elif isinstance(inp, ourlang.StatementIf):
        yield from statement_if(inp)
    elif not isinstance(inp, ourlang.StatementPass):
        raise NotImplementedError(statement, inp)
def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
    """
    Compile: function argument

    The first item of the param is passed through unchanged,
    the second item is compiled with `type_`.
    """
    name = inp[0]
    phasm_type = inp[1]
    return (name, type_(phasm_type))
def import_(inp: ourlang.Function) -> wasm.Import:
    """
    Compile: imported function

    The function must be marked as imported. It is always imported from
    'imports', and its own name is passed for both name arguments of
    `wasm.Import`.
    """
    assert inp.imported

    params = [function_argument(param) for param in inp.posonlyargs]
    returns = type_(inp.returns)

    return wasm.Import('imports', inp.name, inp.name, params, returns)
def function(inp: ourlang.Function) -> wasm.Function:
    """
    Compile: function

    The function must not be imported. Tuple and struct constructors get a
    generated body plus one local that holds the address of the new value;
    any other function has its own statements compiled and gets no locals.
    """
    assert not inp.imported

    if isinstance(inp, ourlang.TupleConstructor):
        statements = list(_generate_tuple_constructor(inp))
        locals_ = [('___new_reference___addr', wasm.WasmTypeInt32())]
    elif isinstance(inp, ourlang.StructConstructor):
        statements = list(_generate_struct_constructor(inp))
        locals_ = [('___new_reference___addr', wasm.WasmTypeInt32())]
    else:
        statements = []
        for stat in inp.statements:
            statements.extend(statement(stat))
        locals_ = []  # FIXME: Implement function locals, if required

    # Only exported functions get an exported name
    exported_name = inp.name if inp.exported else None
    params = [function_argument(param) for param in inp.posonlyargs]

    return wasm.Function(
        inp.name,
        exported_name,
        params,
        locals_,
        type_(inp.returns),
        statements,
    )
def module(inp: ourlang.Module) -> wasm.Module:
    """
    Compile: module

    Imported functions become the imports of the result. The functions of
    the result are the generated runtime helpers, followed by every
    function that is not imported.
    """
    compiled = wasm.Module()

    compiled.imports = [
        import_(func)
        for func in inp.functions.values()
        if func.imported
    ]

    # Helpers that are part of every compiled module
    runtime_functions = [
        _generate____new_reference___(inp),
        _generate_stdlib_alloc___init__(inp),
        _generate_stdlib_alloc___find_free_block__(inp),
        _generate_stdlib_alloc___alloc__(inp),
        _generate____access_bytes_index___(inp),
    ]
    own_functions = [
        function(func)
        for func in inp.functions.values()
        if not func.imported
    ]
    compiled.functions = runtime_functions + own_functions

    return compiled
def _generate____new_reference___(mod: ourlang.Module) -> wasm.Function:
    """
    Generate: ___new_reference___(alloc_size) -> address

    The generated code reads an i32 from memory address 0 and uses it as
    the result. It then writes that value plus alloc_size back to
    address 0, so the next call continues from there.
    """
    params = [
        ('alloc_size', type_(mod.types['i32'])),
    ]
    locals_ = [
        ('result', type_(mod.types['i32'])),
    ]
    returns = type_(mod.types['i32'])

    body = [
        # Target address for the store at the end
        i32.const(0),

        i32.const(0),
        i32.load(),
        wasm.Statement('local.tee', '$result', comment='Address for this call'),

        wasm.Statement('local.get', '$alloc_size'),
        i32.add(),
        i32.store(comment='Address for the next call'),

        wasm.Statement('local.get', '$result'),
    ]

    return wasm.Function(
        '___new_reference___',
        '___new_reference___',
        params,
        locals_,
        returns,
        body,
    )
# Bookkeeping of the stdlib.alloc functions generated below.
# The identifier is the value written to memory address 0 once
# stdlib.alloc.__init__ has run; the other three are the memory
# addresses of the bookkeeping fields.
STDLIB_ALLOC__IDENTIFIER = 0xA1C0
STDLIB_ALLOC__RESERVED0 = 0x04  # Reserved; set to 0 during set up
STDLIB_ALLOC__FREE_BLOCK = 0x08  # Holds the address of the next free block
STDLIB_ALLOC__UNALLOC_ADDR = 0x0C  # Holds the address of the first unallocated byte
def _generate_stdlib_alloc___init__(mod: ourlang.Module) -> wasm.Function:
    """
    Generate: stdlib.alloc.__init__()

    The generated code returns early when memory address 0 already holds
    the allocator identifier. Otherwise it zeroes the reserved and free
    block fields, points the unallocated address at 0x10, and writes the
    identifier to address 0 as the last step.

    The `mod` argument is not used.
    """
    # NOTE(review): wasmeasy.if_ presumably wraps its arguments in an
    # if / end pair - confirm against wasmeasy
    already_set_up_check = [
        i32.const(0),
        i32.load(),
        i32.const(STDLIB_ALLOC__IDENTIFIER),
        i32.eq(),
        *wasmeasy.if_(
            wasm.Statement('return', comment='Already set up'),
        ),
    ]

    write_fields = [
        i32.const(STDLIB_ALLOC__RESERVED0, comment='Reserved'),
        i32.const(0),
        i32.store(),

        i32.const(STDLIB_ALLOC__FREE_BLOCK,
                  comment='Address of next free block'),
        i32.const(0, comment='None to start with'),
        i32.store(),

        i32.const(STDLIB_ALLOC__UNALLOC_ADDR,
                  comment='Address of first unallocated byte'),
        i32.const(0x10),
        i32.store(),
    ]

    mark_done = [
        i32.const(0, comment='Done setting up'),
        i32.const(STDLIB_ALLOC__IDENTIFIER),
        i32.store(),
    ]

    return wasm.Function(
        'stdlib.alloc.__init__',
        'stdlib.alloc.__init__',
        [],
        [],
        wasm.WasmTypeNone(),
        already_set_up_check + write_fields + mark_done,
    )
def _generate_stdlib_alloc___find_free_block__(mod: ourlang.Module) -> wasm.Function:
    """
    Generate: stdlib.alloc.__find_free_block__(alloc_size) -> address

    The generated code returns 0 when the free block field holds 0.
    In any other case it ends in an `unreachable` instruction.

    The `result` local is declared but not used by the generated code.
    """
    params = [
        ('alloc_size', type_(mod.types['i32'])),
    ]
    locals_ = [
        ('result', type_(mod.types['i32'])),
    ]
    returns = type_(mod.types['i32'])

    body = [
        i32.const(STDLIB_ALLOC__FREE_BLOCK),
        i32.load(),
        i32.const(0),
        i32.eq(),
        *wasmeasy.if_(
            i32.const(0),
            wasm.Statement('return'),
        ),

        wasm.Statement('unreachable'),
    ]

    return wasm.Function(
        'stdlib.alloc.__find_free_block__',
        'stdlib.alloc.__find_free_block__',
        params,
        locals_,
        returns,
        body,
    )
def _generate_stdlib_alloc___alloc__(mod: ourlang.Module) -> wasm.Function:
    """
    Generate: stdlib.alloc.__alloc__(alloc_size) -> address

    The generated code:
    - hits `unreachable` when address 0 does not hold the allocator identifier
    - asks stdlib.alloc.__find_free_block__ for a block
    - when that returns 0, takes the block from the unallocated space and
      moves the unallocated address past the 4 byte header and the data
    - stores alloc_size in the header of the block
    - returns the address directly after the header
    """
    params = [
        ('alloc_size', type_(mod.types['i32'])),
    ]
    locals_ = [
        ('result', type_(mod.types['i32'])),
    ]
    returns = type_(mod.types['i32'])

    require_set_up = [
        i32.const(0),
        i32.load(),
        i32.const(STDLIB_ALLOC__IDENTIFIER),
        i32.ne(),
        *wasmeasy.if_(
            wasm.Statement('unreachable'),
        ),
    ]

    find_block = [
        wasm.Statement('local.get', '$alloc_size'),
        wasm.Statement('call', '$stdlib.alloc.__find_free_block__'),
        wasm.Statement('local.set', '$result'),

        # Check if there was a free block
        wasm.Statement('local.get', '$result'),
        i32.const(0),
        i32.eq(),
        *wasmeasy.if_(
            # Use unallocated space
            i32.const(STDLIB_ALLOC__UNALLOC_ADDR),

            i32.const(STDLIB_ALLOC__UNALLOC_ADDR),
            i32.load(),
            wasm.Statement('local.tee', '$result'),

            # Updated unalloc pointer (address already set on stack)
            i32.const(4),  # Header size
            i32.add(),
            wasm.Statement('local.get', '$alloc_size'),
            i32.add(),
            i32.store('offset=0'),
        ),
    ]

    finish_block = [
        # Store block size
        wasm.Statement('local.get', '$result'),
        wasm.Statement('local.get', '$alloc_size'),
        i32.store('offset=0'),

        # Return address of the allocated bytes
        wasm.Statement('local.get', '$result'),
        i32.const(4),  # Header size
        i32.add(),
    ]

    return wasm.Function(
        'stdlib.alloc.__alloc__',
        'stdlib.alloc.__alloc__',
        params,
        locals_,
        returns,
        require_set_up + find_block + finish_block,
    )
def _generate____access_bytes_index___(mod: ourlang.Module) -> wasm.Function:
    """
    Generate: ___access_bytes_index___(byt, ofs) -> byte value

    The i32 stored at $byt is used as the length. When $ofs is below that
    length (unsigned compare), the byte at $byt + 4 + $ofs is returned.
    Otherwise the result is 0.

    The function is not exported.
    """
    params = [
        ('byt', type_(mod.types['i32'])),
        ('ofs', type_(mod.types['i32'])),
    ]
    returns = type_(mod.types['i32'])

    bounds_check = [
        wasm.Statement('local.get', '$ofs'),
        wasm.Statement('local.get', '$byt'),
        i32.load(),
        wasm.Statement('i32.lt_u'),
    ]

    within_bounds = [
        wasm.Statement('if', comment='$ofs < len($byt)'),
        wasm.Statement('local.get', '$byt'),
        wasm.Statement('i32.const', '4', comment='Leading size field'),
        wasm.Statement('i32.add'),
        wasm.Statement('local.get', '$ofs'),
        wasm.Statement('i32.add'),
        wasm.Statement('i32.load8_u', comment='Within bounds'),
        wasm.Statement('return'),
        wasm.Statement('end'),
    ]

    out_of_bounds = [
        wasm.Statement('i32.const', '0', comment='Out of bounds'),
        wasm.Statement('return'),
    ]

    return wasm.Function(
        '___access_bytes_index___',
        None,
        params,
        [],
        returns,
        bounds_check + within_bounds + out_of_bounds,
    )
def _generate_tuple_constructor(inp: ourlang.TupleConstructor) -> Statements:
    """
    Generate: body of a tuple constructor

    Requests memory of the tuple's alloc size via ___new_reference___ and
    keeps the address in the $___new_reference___addr local. Every member is
    then stored at its offset, taking the value from the $arg<idx> local.
    The address is the last thing yielded, which makes it the result.

    Raises NotImplementedError when a member type has no entry in
    LOAD_STORE_TYPE_MAP.
    """
    yield wasm.Statement('i32.const', str(inp.tuple.alloc_size()))
    yield wasm.Statement('call', '$___new_reference___')
    yield wasm.Statement('local.set', '$___new_reference___addr')

    for member in inp.tuple.members:
        mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
        if mtyp is None:
            # In the future might extend this by having structs or tuples
            # as members of struct or tuples
            # Report this generator as the origin, like the other
            # compile functions do for themselves
            raise NotImplementedError(_generate_tuple_constructor, inp, member)

        # Store operands: first the address, then the value
        yield wasm.Statement('local.get', '$___new_reference___addr')
        yield wasm.Statement('local.get', f'$arg{member.idx}')
        yield wasm.Statement(f'{mtyp}.store', 'offset=' + str(member.offset))

    yield wasm.Statement('local.get', '$___new_reference___addr')
def _generate_struct_constructor(inp: ourlang.StructConstructor) -> Statements:
    """
    Generate: body of a struct constructor

    Requests memory of the struct's alloc size via ___new_reference___ and
    keeps the address in the $___new_reference___addr local. Every member is
    then stored at its offset, taking the value from the local that is named
    after the member. The address is the last thing yielded, which makes it
    the result.

    Raises NotImplementedError when a member type has no entry in
    LOAD_STORE_TYPE_MAP.
    """
    yield wasm.Statement('i32.const', str(inp.struct.alloc_size()))
    yield wasm.Statement('call', '$___new_reference___')
    yield wasm.Statement('local.set', '$___new_reference___addr')

    for member in inp.struct.members:
        mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__)
        if mtyp is None:
            # In the future might extend this by having structs or tuples
            # as members of struct or tuples
            # Report this generator as the origin, like the other
            # compile functions do for themselves
            raise NotImplementedError(_generate_struct_constructor, inp, member)

        # Store operands: first the address, then the value
        yield wasm.Statement('local.get', '$___new_reference___addr')
        yield wasm.Statement('local.get', f'${member.name}')
        yield wasm.Statement(f'{mtyp}.store', 'offset=' + str(member.offset))

    yield wasm.Statement('local.get', '$___new_reference___addr')