diff --git a/phasm/compiler.py b/phasm/compiler.py index 2e214a0..5040902 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -39,6 +39,16 @@ INSTANCES = { 'a=f32': stdlib_types.f32_eq_equals, 'a=f64': stdlib_types.f64_eq_equals, }, + type3classes.Fractional.operators['/']: { + 'a=f32': stdlib_types.f32_fractional_div, + 'a=f64': stdlib_types.f64_fractional_div, + }, + type3classes.Integral.methods['div']: { + 'a=u32': stdlib_types.u32_integral_div, + 'a=u64': stdlib_types.u64_integral_div, + 'a=i32': stdlib_types.i32_integral_div, + 'a=i64': stdlib_types.i64_integral_div, + }, type3classes.Num.operators['+']: { 'a=u32': stdlib_types.u32_num_add, 'a=u64': stdlib_types.u64_num_add, @@ -157,7 +167,6 @@ U32_OPERATOR_MAP = { '^': 'xor', '|': 'or', '&': 'and', - '/': 'div_u' # Division by zero is a trap and the program will panic } U64_OPERATOR_MAP = { @@ -170,7 +179,6 @@ U64_OPERATOR_MAP = { '^': 'xor', '|': 'or', '&': 'and', - '/': 'div_u' # Division by zero is a trap and the program will panic } I32_OPERATOR_MAP = { @@ -178,7 +186,6 @@ I32_OPERATOR_MAP = { '>': 'gt_s', '<=': 'le_s', '>=': 'ge_s', - '/': 'div_s' # Division by zero is a trap and the program will panic } I64_OPERATOR_MAP = { @@ -186,15 +193,6 @@ I64_OPERATOR_MAP = { '>': 'gt_s', '<=': 'le_s', '>=': 'ge_s', - '/': 'div_s' # Division by zero is a trap and the program will panic -} - -F32_OPERATOR_MAP = { - '/': 'div' # Division by zero is a trap and the program will panic -} - -F64_OPERATOR_MAP = { - '/': 'div' # Division by zero is a trap and the program will panic } def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: @@ -397,14 +395,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if inp.type3 == type3types.f32: - if operator := F32_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'f32.{operator}') - return - if inp.type3 == type3types.f64: - if operator := F64_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'f64.{operator}') - return raise NotImplementedError(expression, inp.operator, inp.left.type3, inp.right.type3, inp.type3) @@ -439,6 +429,30 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: for arg in inp.arguments: expression(wgn, arg) + if isinstance(inp.function, type3classes.Type3ClassMethod): + # FIXME: Duplicate code with BinaryOp + type_var_map = {} + + for type_var, arg_expr in zip(inp.function.signature, inp.arguments + [inp]): + if not isinstance(type_var, type3classes.TypeVariable): + # Fixed type, not part of the lookup requirements + continue + + assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + type_var_map[type_var] = arg_expr.type3 + + instance_key = ','.join( + f'{k.letter}={v.name}' + for k, v in type_var_map.items() + ) + + instance = INSTANCES.get(inp.function, {}).get(instance_key, None) + if instance is not None: + instance(wgn) + return + + raise NotImplementedError(inp.function, instance_key) + wgn.add_statement('call', '${}'.format(inp.function.name)) return diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 3c6dc76..ae242dd 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -6,7 +6,7 @@ from typing import Dict, Iterable, List, Optional, Union from typing_extensions import Final -from .type3 import typeclasses as type3classes +from .type3 import typeclasses as type3typeclasses from .type3 import types as type3types from .type3.types import PlaceholderForType, StructType3, Type3, Type3OrPlaceholder @@ -150,11 +150,11 @@ class BinaryOp(Expression): """ __slots__ = ('operator', 'left', 'right', ) - operator: Union[str, type3classes.Type3ClassMethod] + operator: Union[str, type3typeclasses.Type3ClassMethod] left: Expression right: Expression - def __init__(self, operator: Union[str, type3classes.Type3ClassMethod], left: Expression, right: Expression) -> None: + def __init__(self, operator: Union[str, type3typeclasses.Type3ClassMethod], left: Expression, right: Expression) -> None: super().__init__() self.operator = operator @@ -170,10 +170,10 @@ class FunctionCall(Expression): """ __slots__ = ('function', 'arguments', ) - function: 'Function' + function: Union['Function', type3typeclasses.Type3ClassMethod] arguments: List[Expression] - def __init__(self, function: 'Function') -> None: + def __init__(self, function: Union['Function', type3typeclasses.Type3ClassMethod]) -> None: super().__init__() self.function = function diff --git a/phasm/parser.py b/phasm/parser.py index 55b6ed4..61bdfa2 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -37,9 +37,18 @@ from .type3 import types as type3types PRELUDE_OPERATORS = { **type3typeclasses.Eq.operators, + **type3typeclasses.Fractional.operators, + **type3typeclasses.Integral.operators, **type3typeclasses.Num.operators, } +PRELUDE_METHODS = { + **type3typeclasses.Eq.methods, + **type3typeclasses.Fractional.methods, + **type3typeclasses.Integral.methods, + **type3typeclasses.Num.methods, +} + def phasm_parse(source: str) -> Module: """ Public method for parsing Phasm code into a Phasm Module @@ -465,7 +474,11 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.func.id in module.struct_definitions: + func: Union[Function, type3typeclasses.Type3ClassMethod] + + if node.func.id in PRELUDE_METHODS: + func = PRELUDE_METHODS[node.func.id] + elif node.func.id in module.struct_definitions: struct_definition = module.struct_definitions[node.func.id] struct_constructor = StructConstructor(struct_definition.struct_type3) @@ -524,13 +537,15 @@ class OurVisitor: func = module.functions[node.func.id] - if len(func.posonlyargs) != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') + exp_arg_count = len(func.posonlyargs) if isinstance(func, Function) else len(func.signature) - 1 + + if exp_arg_count != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires {exp_arg_count} arguments but {len(node.args)} are given') result = FunctionCall(func) result.arguments.extend( self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) - for arg_expr, param in zip(node.args, func.posonlyargs) + for arg_expr in node.args ) return result diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 11ed170..2aa906c 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -90,6 +90,24 @@ def f32_eq_equals(g: Generator) -> None: def f64_eq_equals(g: Generator) -> None: g.add_statement('f64.eq') +def f32_fractional_div(g: Generator) -> None: + g.add_statement('f32.div') + +def f64_fractional_div(g: Generator) -> None: + g.add_statement('f64.div') + +def u32_integral_div(g: Generator) -> None: + g.add_statement('i32.div_u') + +def u64_integral_div(g: Generator) -> None: + g.add_statement('i64.div_u') + +def i32_integral_div(g: Generator) -> None: + g.add_statement('i32.div_s') + +def i64_integral_div(g: Generator) -> None: + g.add_statement('i64.div_s') + def u32_num_add(g: Generator) -> None: g.add_statement('i32.add') diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 9aa315d..8927747 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -95,10 +95,8 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: exp_type = type3types.LOOKUP_TABLE[sig_part.name] yield SameTypeConstraint(exp_type, arg_expr.type3) continue - return - if inp.operator in ('|', '&', '^', ): yield from expression(ctx, inp.left) yield from expression(ctx, inp.right) @@ -117,15 +115,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: comment=f'({inp.operator}) :: a -> a -> a') return - if inp.operator in ('/', ): - yield from expression(ctx, inp.left) - yield from expression(ctx, inp.right) - - yield MustImplementTypeClassConstraint('Fractional', inp.left.type3) - yield SameTypeConstraint(inp.left.type3, inp.right.type3, inp.type3, - comment=f'({inp.operator}) :: a -> a -> a') - return - if inp.operator == '==': yield from expression(ctx, inp.left) yield from expression(ctx, inp.right) @@ -151,6 +140,39 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: raise NotImplementedError(expression, inp) if isinstance(inp, ourlang.FunctionCall): + if isinstance(inp.function, type3typeclasses.Type3ClassMethod): + # FIXME: Duplicate code with BinaryOp + + type_var_map = { + x: type3types.PlaceholderForType([]) + for x in inp.function.signature + if isinstance(x, type3typeclasses.TypeVariable) + } + + for call_arg in inp.arguments: + yield from expression(ctx, call_arg) + + for type_var in inp.function.type3_class.args: + assert type_var in type_var_map # When can this happen? + + yield MustImplementTypeClassConstraint( + inp.function.type3_class, + type_var_map[type_var], + ) + + for sig_part, arg_expr in zip(inp.function.signature, inp.arguments + [inp]): + if isinstance(sig_part, type3typeclasses.TypeVariable): + yield SameTypeConstraint(type_var_map[sig_part], arg_expr.type3) + continue + + if isinstance(sig_part, type3typeclasses.TypeReference): + # On key error: We probably have to a lot of work to do refactoring + # the type lookups + exp_type = type3types.LOOKUP_TABLE[sig_part.name] + yield SameTypeConstraint(exp_type, arg_expr.type3) + continue + return + yield SameTypeConstraint(inp.function.returns_type3, inp.type3, comment=f'The type of a function call to {inp.function.name} is the same as the type that the function returns') diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py index 4eb6adc..6670c34 100644 --- a/phasm/type3/typeclasses.py +++ b/phasm/type3/typeclasses.py @@ -88,6 +88,10 @@ Eq = Type3Class('Eq', ['a'], methods={}, operators={ '==': 'a -> a -> bool', }) +Fractional = Type3Class('Fractional', ['a'], methods={}, operators={ + '/': 'a -> a -> a', +}) + Integral = Type3Class('Eq', ['a'], methods={ 'div': 'a -> a -> a', }, operators={}) diff --git a/phasm/type3/types.py b/phasm/type3/types.py index 225e8d8..2b2ad9e 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -6,7 +6,7 @@ constraint generator works with. """ from typing import Any, Dict, Iterable, List, Optional, Protocol, Union -from .typeclasses import Eq, Num, Type3Class +from .typeclasses import Eq, Fractional, Integral, Num, Type3Class TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method' @@ -242,28 +242,28 @@ The bool type, either True or False Suffixes with an underscores, as it's a Python builtin """ -u8 = PrimitiveType3('u8', [Eq]) +u8 = PrimitiveType3('u8', [Eq, Integral]) """ The unsigned 8-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^8. """ -u32 = PrimitiveType3('u32', [Eq, Num]) +u32 = PrimitiveType3('u32', [Eq, Integral, Num]) """ The unsigned 32-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^32. """ -u64 = PrimitiveType3('u64', [Eq, Num]) +u64 = PrimitiveType3('u64', [Eq, Integral, Num]) """ The unsigned 64-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^64. """ -i8 = PrimitiveType3('i8', [Eq]) +i8 = PrimitiveType3('i8', [Eq, Integral]) """ The signed 8-bit integer type. @@ -271,7 +271,7 @@ Operations on variables employ modular arithmetic, with modulus 2^8, but with the middel point being 0. """ -i32 = PrimitiveType3('i32', [Eq, Num]) +i32 = PrimitiveType3('i32', [Eq, Integral, Num]) """ The unsigned 32-bit integer type. @@ -279,7 +279,7 @@ Operations on variables employ modular arithmetic, with modulus 2^32, but with the middel point being 0. """ -i64 = PrimitiveType3('i64', [Eq, Num]) +i64 = PrimitiveType3('i64', [Eq, Integral, Num]) """ The unsigned 64-bit integer type. @@ -287,12 +287,12 @@ Operations on variables employ modular arithmetic, with modulus 2^64, but with the middel point being 0. """ -f32 = PrimitiveType3('f32', [Eq, Num]) +f32 = PrimitiveType3('f32', [Eq, Fractional, Num]) """ A 32-bits IEEE 754 float, of 32 bits width. """ -f64 = PrimitiveType3('f64', [Eq, Num]) +f64 = PrimitiveType3('f64', [Eq, Fractional, Num]) """ A 32-bits IEEE 754 float, of 64 bits width. """