Implements ceil, floor, trunc, nearest

To round of the f32 / f64 wasm supported opcodes.

This also means we can remove the now outdated
WEBASSEMBLY_BUILTIN_FLOAT_OPS.
This commit is contained in:
Johan B.W. de Vries 2025-04-06 16:38:57 +02:00
parent 9bc8d94ffd
commit 46dbc90475
8 changed files with 113 additions and 57 deletions

View File

@ -13,6 +13,7 @@
- Allocation is done using pointers for members, is this desired? - Allocation is done using pointers for members, is this desired?
- Functions don't seem to be a thing on typing level yet? - Functions don't seem to be a thing on typing level yet?
- static_array and tuple should probably not be PrimitiveType3, but instead subclass AppliedType3? - static_array and tuple should probably not be PrimitiveType3, but instead subclass AppliedType3?
- See if we want to replace Fractional with Real, and add Rational, Irrationl, Algebraic, Transendental
- test_bitwise_or_inv_type - test_bitwise_or_inv_type
- test_bytes_index_out_of_bounds vs static trap(?) - test_bytes_index_out_of_bounds vs static trap(?)
@ -22,8 +23,9 @@
- Either there should be more of them or less - Either there should be more of them or less
- At first glance, looks like failure in the typing system - At first glance, looks like failure in the typing system
- Related to the FIXME in phasm_type3? - Related to the FIXME in phasm_type3?
- WEBASSEMBLY_BUILTIN_FLOAT_OPS and WEBASSEMBLY_BUILTIN_BYTES_OPS are special cased - WEBASSEMBLY_BUILTIN_BYTES_OPS is special cased
- Should be part of a prelude - Should be part of a prelude (?)
- In Haskell this is not a type class
- Casting is not implemented except u32 which is special cased - Casting is not implemented except u32 which is special cased
- Parser is putting stuff in ModuleDataBlock - Parser is putting stuff in ModuleDataBlock
- Compiler should probably do that - Compiler should probably do that

View File

@ -90,8 +90,7 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
if ( if (
inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS):
or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS):
return f'{inp.operator}({expression(inp.right)})' return f'{inp.operator}({expression(inp.right)})'
if inp.operator == 'cast': if inp.operator == 'cast':

View File

@ -53,6 +53,22 @@ INSTANCES = {
'a=f32': stdlib_types.f32_floating_sqrt, 'a=f32': stdlib_types.f32_floating_sqrt,
'a=f64': stdlib_types.f64_floating_sqrt, 'a=f64': stdlib_types.f64_floating_sqrt,
}, },
type3classes.Fractional.methods['ceil']: {
'a=f32': stdlib_types.f32_fractional_ceil,
'a=f64': stdlib_types.f64_fractional_ceil,
},
type3classes.Fractional.methods['floor']: {
'a=f32': stdlib_types.f32_fractional_floor,
'a=f64': stdlib_types.f64_fractional_floor,
},
type3classes.Fractional.methods['trunc']: {
'a=f32': stdlib_types.f32_fractional_trunc,
'a=f64': stdlib_types.f64_fractional_trunc,
},
type3classes.Fractional.methods['nearest']: {
'a=f32': stdlib_types.f32_fractional_nearest,
'a=f64': stdlib_types.f64_fractional_nearest,
},
type3classes.Fractional.operators['/']: { type3classes.Fractional.operators['/']: {
'a=f32': stdlib_types.f32_fractional_div, 'a=f32': stdlib_types.f32_fractional_div,
'a=f64': stdlib_types.f64_fractional_div, 'a=f64': stdlib_types.f64_fractional_div,
@ -429,15 +445,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if inp.type3 == type3types.f32:
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f32.{inp.operator}')
return
if inp.type3 == type3types.f64:
if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS:
wgn.add_statement(f'f64.{inp.operator}')
return
if inp.type3 == type3types.u32: if inp.type3 == type3types.u32:
if inp.operator == 'len': if inp.operator == 'len':
if inp.right.type3 == type3types.bytes: if inp.right.type3 == type3types.bytes:

View File

@ -10,7 +10,6 @@ from .type3 import typeclasses as type3typeclasses
from .type3 import types as type3types from .type3 import types as type3types
from .type3.types import PlaceholderForType, StructType3, Type3, Type3OrPlaceholder from .type3.types import PlaceholderForType, StructType3, Type3, Type3OrPlaceholder
WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'ceil', 'floor', 'trunc', 'nearest', )
WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', ) WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', )
class Expression: class Expression:

View File

@ -6,7 +6,6 @@ from typing import Any, Dict, NoReturn, Union
from .exceptions import StaticError from .exceptions import StaticError
from .ourlang import ( from .ourlang import (
WEBASSEMBLY_BUILTIN_FLOAT_OPS,
AccessStructMember, AccessStructMember,
BinaryOp, BinaryOp,
ConstantBytes, ConstantBytes,
@ -490,14 +489,6 @@ class OurVisitor:
# FIXME: Defer struct de-allocation # FIXME: Defer struct de-allocation
func = module.functions[struct_constructor.name] func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS:
if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
return UnaryOp(
'sqrt',
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]),
)
elif node.func.id == 'u32': elif node.func.id == 'u32':
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')

View File

@ -174,11 +174,38 @@ def f64_eq_not_equals(g: Generator) -> None:
## ### ## ###
## class Fractional ## class Fractional
def f32_fractional_ceil(g: Generator) -> None:
g.f32.ceil()
def f64_fractional_ceil(g: Generator) -> None:
g.f64.ceil()
def f32_fractional_floor(g: Generator) -> None:
g.f32.floor()
def f64_fractional_floor(g: Generator) -> None:
g.f64.floor()
def f32_fractional_trunc(g: Generator) -> None:
g.f32.trunc()
def f64_fractional_trunc(g: Generator) -> None:
g.f64.trunc()
def f32_fractional_nearest(g: Generator) -> None:
g.f32.nearest()
def f64_fractional_nearest(g: Generator) -> None:
g.f64.nearest()
def f32_fractional_div(g: Generator) -> None: def f32_fractional_div(g: Generator) -> None:
g.add_statement('f32.div') g.f32.div()
def f64_fractional_div(g: Generator) -> None: def f64_fractional_div(g: Generator) -> None:
g.add_statement('f64.div') g.f64.div()
## ###
## class Floating
def f32_floating_sqrt(g: Generator) -> None: def f32_floating_sqrt(g: Generator) -> None:
g.add_statement('f32.sqrt') g.add_statement('f32.sqrt')

View File

@ -113,7 +113,12 @@ Integral = Type3Class('Eq', ['a'], methods={
'div': 'a -> a -> a', 'div': 'a -> a -> a',
}, operators={}, inherited_classes=[NatNum]) }, operators={}, inherited_classes=[NatNum])
Fractional = Type3Class('Fractional', ['a'], methods={}, operators={ Fractional = Type3Class('Fractional', ['a'], methods={
'ceil': 'a -> a',
'floor': 'a -> a',
'trunc': 'a -> a',
'nearest': 'a -> a',
}, operators={
'/': 'a -> a -> a', '/': 'a -> a -> a',
}, inherited_classes=[NatNum]) }, inherited_classes=[NatNum])

View File

@ -4,44 +4,70 @@ from ..helpers import Suite
TYPE_LIST = ['f32', 'f64'] TYPE_LIST = ['f32', 'f64']
TEST_LIST = [
('10.0 / 8.0', 1.25, ),
# WebAssembly dictates that float division follows the IEEE rules
# https://www.w3.org/TR/wasm-core-1/#-hrefop-fdivmathrmfdiv_n-z_1-z_2
('10.0 / 0.0', float('+inf') , ),
('-10.0 / 0.0', float('-inf') , ),
( 'ceil(4.5)', 5.0, ),
( 'ceil(4.75)', 5.0, ),
( 'ceil(5.0)', 5.0, ),
( 'ceil(5.25)', 6.0, ),
( 'ceil(5.5)', 6.0, ),
('ceil(-4.5)', -4.0, ),
('ceil(-4.75)', -4.0, ),
('ceil(-5.0)' , -5.0, ),
('ceil(-5.25)', -5.0, ),
('ceil(-5.5)', -5.0, ),
( 'floor(4.5)', 4.0, ),
( 'floor(4.75)', 4.0, ),
( 'floor(5.0)', 5.0, ),
( 'floor(5.25)', 5.0, ),
( 'floor(5.5)', 5.0, ),
('floor(-4.5)', -5.0, ),
('floor(-4.75)', -5.0, ),
('floor(-5.0)' , -5.0, ),
('floor(-5.25)', -6.0, ),
('floor(-5.5)', -6.0, ),
( 'trunc(4.5)', 4.0, ),
( 'trunc(4.75)', 4.0, ),
( 'trunc(5.0)', 5.0, ),
( 'trunc(5.25)', 5.0, ),
( 'trunc(5.5)', 5.0, ),
('trunc(-4.5)', -4.0, ),
('trunc(-4.75)', -4.0, ),
('trunc(-5.0)' , -5.0, ),
('trunc(-5.25)', -5.0, ),
('trunc(-5.5)', -5.0, ),
( 'nearest(4.5)', 4.0, ),
( 'nearest(4.75)', 5.0, ),
( 'nearest(5.0)', 5.0, ),
( 'nearest(5.25)', 5.0, ),
( 'nearest(5.5)', 6.0, ),
('nearest(-4.5)', -4.0, ),
('nearest(-4.75)', -5.0, ),
('nearest(-5.0)', -5.0, ),
('nearest(-5.25)', -5.0, ),
('nearest(-5.5)', -6.0, ),
]
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_LIST) @pytest.mark.parametrize('type_', TYPE_LIST)
def test_division_float(type_): @pytest.mark.parametrize('test_in,test_out', TEST_LIST)
def test_fractional(type_, test_in, test_out):
code_py = f""" code_py = f"""
@exported @exported
def testEntry() -> {type_}: def testEntry() -> {type_}:
return 10.0 / 8.0 return {test_in}
""" """
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
assert 1.25 == result.returned_value assert test_out == result.returned_value
assert isinstance(result.returned_value, float) assert isinstance(result.returned_value, float)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_LIST)
def test_division_float_follow_ieee_so_inf_pos(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return 10.0 / 0.0
"""
# WebAssembly dictates that float division follows the IEEE rules
# https://www.w3.org/TR/wasm-core-1/#-hrefop-fdivmathrmfdiv_n-z_1-z_2
result = Suite(code_py).run_code()
assert float('+inf') == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', TYPE_LIST)
def test_division_float_follow_ieee_so_inf_neg(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
return -10.0 / 0.0
"""
# WebAssembly dictates that float division follows the IEEE rules
# https://www.w3.org/TR/wasm-core-1/#-hrefop-fdivmathrmfdiv_n-z_1-z_2
result = Suite(code_py).run_code()
assert float('-inf') == result.returned_value