Implements the IO type constructor and Monad type class

This commit is contained in:
Johan B.W. de Vries 2025-08-23 15:33:23 +02:00
parent 71691d68e9
commit 3d6d279408
11 changed files with 193 additions and 8 deletions

View File

@ -45,6 +45,7 @@ class BuildBase[G]:
__slots__ = (
'dynamic_array_type5_constructor',
'function_type5_constructor',
'io_type5_constructor',
'static_array_type5_constructor',
'tuple_type5_constructor_map',
@ -83,6 +84,18 @@ class BuildBase[G]:
See type5_make_function and type5_is_function.
"""
io_type5_constructor: type5typeexpr.TypeConstructor
"""
Constructor for IO.
An IO function is a function that can have side effects.
It can do input or output. Other functions cannot have
side effects, and can only return a value based on the
input given.
See type5_make_io and type5_is_io.
"""
static_array_type5_constructor: type5typeexpr.TypeConstructor
"""
Constructor for arrays of compiled time determined length.
@ -207,6 +220,7 @@ class BuildBase[G]:
self.dynamic_array_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> S, name="dynamic_array")
self.function_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> (S >> S), name="function")
self.io_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> S, name="IO")
self.static_array_type5_constructor = type5typeexpr.TypeConstructor(kind=N >> (S >> S), name='static_array')
self.tuple_type5_constructor_map = {}
@ -344,6 +358,19 @@ class BuildBase[G]:
return my_args + more_args
def type5_make_io(self, arg: type5typeexpr.TypeExpr) -> type5typeexpr.TypeApplication:
return type5typeexpr.TypeApplication(
constructor=self.io_type5_constructor,
argument=arg
)
def type5_is_io(self, typeexpr: type5typeexpr.TypeExpr | type5constrainedexpr.ConstrainedExpr) -> type5typeexpr.TypeExpr | None:
if not isinstance(typeexpr, type5typeexpr.TypeApplication):
return None
if typeexpr.constructor != self.io_type5_constructor:
return None
return typeexpr.argument
def type5_make_tuple(self, args: Sequence[type5typeexpr.TypeExpr]) -> type5typeexpr.TypeApplication:
if not args:
raise TypeError("Tuples must at least one field")

View File

@ -22,6 +22,7 @@ from .typeclasses import (
fractional,
integral,
intnum,
monad,
natnum,
ord,
promotable,
@ -68,6 +69,7 @@ class BuildDefault(BuildBase[Generator]):
integral,
foldable, subscriptable,
sized,
monad,
]
for tc in tc_list:

View File

@ -0,0 +1,26 @@
"""
The Monad type class is defined for type constructors that cause one thing to happen /after/ another.
"""
from __future__ import annotations
from typing import Any
from ...type5.constrainedexpr import ConstrainedExpr
from ...type5.kindexpr import Star
from ...type5.typeexpr import TypeVariable
from ...typeclass import TypeClass, TypeClassConstraint
from ...wasmgenerator import Generator as WasmGenerator
from ..base import BuildBase
def load(build: BuildBase[Any]) -> None:
a = TypeVariable(kind=Star(), name='a')
Monad = TypeClass('Monad', (a, ), methods={}, operators={})
build.register_type_class(Monad)
def wasm(build: BuildBase[WasmGenerator]) -> None:
Monad = build.type_classes['Monad']
build.instance_type_class(Monad, build.io_type5_constructor)

View File

@ -36,6 +36,10 @@ class BuildTypeRouter[T](TypeRouter[T]):
if fn_args is not None:
return self.when_function(fn_args)
io_args = self.build.type5_is_io(typ)
if io_args is not None:
return self.when_io(io_args)
sa_args = self.build.type5_is_static_array(typ)
if sa_args is not None:
sa_len, sa_typ = sa_args
@ -59,6 +63,9 @@ class BuildTypeRouter[T](TypeRouter[T]):
def when_function(self, fn_args: list[TypeExpr]) -> T:
raise NotImplementedError
def when_io(self, io_arg: TypeExpr) -> T:
raise NotImplementedError
def when_struct(self, typ: Record) -> T:
raise NotImplementedError
@ -94,6 +101,9 @@ class TypeName(BuildTypeRouter[str]):
def when_function(self, fn_args: list[TypeExpr]) -> str:
return 'Callable[' + ', '.join(map(self, fn_args)) + ']'
def when_io(self, io_arg: TypeExpr) -> str:
return 'IO[' + self(io_arg) + ']'
def when_static_array(self, sa_len: int, sa_typ: TypeExpr) -> str:
return f'{self(sa_typ)}[{sa_len}]'

View File

@ -124,6 +124,10 @@ def statement(inp: ourlang.Statement) -> Statements:
yield ''
return
if isinstance(inp, ourlang.StatementCall):
yield expression(inp.call)
return
raise NotImplementedError(statement, inp)
def function(mod: ourlang.Module[Any], inp: ourlang.Function) -> str:

View File

@ -43,6 +43,11 @@ def type5(mod: ourlang.Module[WasmGenerator], inp: TypeExpr) -> wasm.WasmType:
Types are used for example in WebAssembly function parameters
and return types.
"""
io_arg = mod.build.type5_is_io(inp)
if io_arg is not None:
# IO is type constructor that only exists on the typing layer
inp = io_arg
typ_info = mod.build.type_info_map.get(inp.name)
if typ_info is None:
typ_info = mod.build.type_info_constructed
@ -404,6 +409,9 @@ def statement_if(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ou
# for stat in inp.else_statements:
# statement(wgn, stat)
def statement_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourlang.Function, inp: ourlang.StatementCall) -> None:
expression(wgn, mod, inp.call)
def statement(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourlang.Function, inp: ourlang.Statement) -> None:
"""
Compile: any statement
@ -416,6 +424,10 @@ def statement(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourla
statement_if(wgn, mod, fun, inp)
return
if isinstance(inp, ourlang.StatementCall):
statement_call(wgn, mod, fun, inp)
return
if isinstance(inp, ourlang.StatementPass):
return

View File

@ -282,6 +282,24 @@ class StatementReturn(Statement):
def __repr__(self) -> str:
return f'StatementReturn({repr(self.value)})'
class StatementCall(Statement):
"""
A function call within a function.
Executing is deferred to the given function until it completes.
"""
__slots__ = ('call')
call: FunctionCall
def __init__(self, call: FunctionCall, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.call = call
def __repr__(self) -> str:
return f'StatementCall({repr(self.call)})'
class StatementIf(Statement):
"""
An if statement within a function

View File

@ -25,6 +25,7 @@ from .ourlang import (
ModuleDataBlock,
SourceRef,
Statement,
StatementCall,
StatementIf,
StatementPass,
StatementReturn,
@ -361,6 +362,9 @@ class OurVisitor[G]:
return result
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
return StatementCall(self.visit_Module_FunctionDef_Call(module, function, our_locals, node.value), srf(module, node))
if isinstance(node, ast.Pass):
return StatementPass(srf(module, node))
@ -484,7 +488,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> FunctionCall:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -647,6 +651,15 @@ class OurVisitor[G]:
for e in func_arg_types
])
if isinstance(node.value, ast.Name) and node.value.id == 'IO':
assert isinstance(node.slice, ast.Name) or (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 0)
return module.build.type5_make_io(
self.visit_type5(module, node.slice)
)
# TODO: This u32[...] business is messing up the other type constructors
if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index')
@ -672,6 +685,9 @@ class OurVisitor[G]:
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.elts:
return module.build.unit_type5
return module.build.type5_make_tuple(
[self.visit_type5(module, elt) for elt in node.elts],
)

View File

@ -15,7 +15,7 @@ from .constraints import (
TypeClassInstanceExistsConstraint,
UnifyTypesConstraint,
)
from .kindexpr import KindExpr
from .kindexpr import KindExpr, Star
from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete
ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -233,12 +233,13 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement
if fun.type5 is None:
raise NotImplementedError("Deducing function type - you'll have to annotate it.")
if isinstance(fun.type5, TypeApplication):
args = ctx.build.type5_is_function(fun.type5)
assert args is not None
type5 = args[-1]
else:
type5 = fun.type5.expr if isinstance(fun.type5, ConstrainedExpr) else fun.type5
args = ctx.build.type5_is_function(fun.type5)
assert args is not None
type5 = args[-1]
# This is a hack to allow return statement in pure and non pure functions
if (io_arg := ctx.build.type5_is_io(type5)):
type5 = io_arg
yield from expression(ctx, inp.value, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns')
@ -256,6 +257,27 @@ def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf)
for stmt in inp.else_statements:
yield from statement(ctx, fun, stmt)
def statement_call(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementCall) -> ConstraintGenerator:
if fun.type5 is None:
raise NotImplementedError("Deducing function type - you'll have to annotate it.")
fn_args = ctx.build.type5_is_function(fun.type5)
assert fn_args is not None
fn_ret = fn_args[-1]
call_phft = ctx.make_placeholder(inp.call)
S = Star()
t_phft = ctx.make_placeholder(kind=S >> S)
a_phft = ctx.make_placeholder(kind=S)
yield from expression_function_call(ctx, inp.call, call_phft)
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, 'Monad', [t_phft])
yield UnifyTypesConstraint(ctx, inp.sourceref, TypeApplication(constructor=t_phft, argument=a_phft), fn_ret)
yield UnifyTypesConstraint(ctx, inp.sourceref, TypeApplication(constructor=t_phft, argument=ctx.build.unit_type5), call_phft)
def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> ConstraintGenerator:
if isinstance(inp, ourlang.StatementReturn):
yield from statement_return(ctx, fun, inp)
@ -265,12 +287,18 @@ def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> Co
yield from statement_if(ctx, fun, inp)
return
if isinstance(inp, ourlang.StatementCall):
yield from statement_call(ctx, fun, inp)
return
raise NotImplementedError(inp)
def function(ctx: Context, inp: ourlang.Function) -> ConstraintGenerator:
for stmt in inp.statements:
yield from statement(ctx, inp, stmt)
# TODO: If function is imported or exported, it should be an IO[..] function
def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> ConstraintGenerator:
phft = ctx.make_placeholder(inp.constant)

View File

@ -139,6 +139,10 @@ class Extractor(BuildTypeRouter[ExtractorFunc]):
return DynamicArrayExtractor(self.access, self(da_arg))
def when_io(self, io_arg: TypeExpr) -> ExtractorFunc:
# IO is a type only annotation, it is not related to allocation
return self(io_arg)
def when_static_array(self, sa_len: int, sa_typ: TypeExpr) -> ExtractorFunc:
return StaticArrayExtractor(self.access, sa_len, self(sa_typ))

View File

@ -0,0 +1,38 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_io_use_type_class():
code_py = f"""
@exported
def testEntry() -> IO[u32]:
return 4
"""
result = Suite(code_py).run_code()
assert 4 == result.returned_value
@pytest.mark.integration_test
def test_io_call_io_function():
code_py = f"""
@imported
def log(val: u32) -> IO[()]:
pass
@exported
def testEntry() -> IO[u32]:
log(123)
return 4
"""
log_history: list[Any] = []
def my_log(val: int) -> None:
log_history.append(val)
result = Suite(code_py).run_code(imports={
'log': my_log,
})
assert 4 == result.returned_value
assert [123] == log_history