diff --git a/phasm/codestyle.py b/phasm/codestyle.py index d0e9c70..bbb2eff 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -86,14 +86,8 @@ def expression(inp: ourlang.Expression) -> str: """ Render: A Phasm expression """ - if isinstance(inp, ( - ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, - ourlang.ConstantInt32, ourlang.ConstantInt64, - )): - return str(inp.value) - - if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )): - # These might not round trip if the original constant + if isinstance(inp, ourlang.ConstantPrimitive): + # Floats might not round trip if the original constant # could not fit in the given float type return str(inp.value) diff --git a/phasm/compiler.py b/phasm/compiler.py index d36561d..6505982 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -131,33 +131,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: """ Compile: Any expression """ - if isinstance(inp, ourlang.ConstantUInt8): - wgn.i32.const(inp.value) - return + if isinstance(inp, ourlang.ConstantPrimitive): + stp = typing.simplify(inp.type_var) + if stp is None: + raise NotImplementedError(f'Constants with type {inp.type_var}') - if isinstance(inp, ourlang.ConstantUInt32): - wgn.i32.const(inp.value) - return + if stp == 'u8': + # No native u8 type - treat as i32, with caution + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantUInt64): - wgn.i64.const(inp.value) - return + if stp in ('i32', 'u32'): + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt32): - wgn.i32.const(inp.value) - return + if stp in ('i64', 'u64'): + wgn.i64.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt64): - wgn.i64.const(inp.value) - return - - if isinstance(inp, ourlang.ConstantFloat32): - wgn.f32.const(inp.value) - return - - if isinstance(inp, ourlang.ConstantFloat64): - wgn.f64.const(inp.value) - return + raise NotImplementedError(f'Constants with type {stp}') if isinstance(inp, ourlang.VariableReference): wgn.add_statement('local.get', '${}'.format(inp.variable.name)) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index efc1a66..6496a2e 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -44,88 +44,15 @@ class Constant(Expression): """ __slots__ = () -class ConstantUInt8(Constant): +class ConstantPrimitive(Constant): """ - An UInt8 constant value expression within a statement + An primitive constant value expression within a statement """ __slots__ = ('value', ) - value: int + value: Union[int, float] - def __init__(self, type_: TypeUInt8, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt32(Constant): - """ - An UInt32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt64(Constant): - """ - An UInt64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt32(Constant): - """ - An Int32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt64(Constant): - """ - An Int64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat32(Constant): - """ - An Float32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat32, value: float) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat64(Constant): - """ - An Float64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat64, value: float) -> None: - super().__init__(type_) + def __init__(self, value: Union[int, float]) -> None: self.value = value class ConstantTuple(Constant): @@ -134,9 +61,9 @@ class ConstantTuple(Constant): """ __slots__ = ('value', ) - value: List[Constant] + value: List[ConstantPrimitive] - def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: + def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? super().__init__(type_) self.value = value @@ -146,9 +73,9 @@ class ConstantStaticArray(Constant): """ __slots__ = ('value', ) - value: List[Constant] + value: List[ConstantPrimitive] - def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None: + def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? super().__init__(type_) self.value = value @@ -455,10 +382,10 @@ class ModuleDataBlock: """ __slots__ = ('data', 'address', ) - data: List[Constant] + data: List[ConstantPrimitive] address: Optional[int] - def __init__(self, data: List[Constant]) -> None: + def __init__(self, data: List[ConstantPrimitive]) -> None: self.data = data self.address = None diff --git a/phasm/parser.py b/phasm/parser.py index 2aa25fc..8b1b3a2 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -35,9 +35,7 @@ from .ourlang import ( AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, Constant, - ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, - ConstantUInt8, ConstantUInt32, ConstantUInt64, - ConstantTuple, ConstantStaticArray, + ConstantPrimitive, ConstantTuple, ConstantStaticArray, FunctionCall, StructConstructor, TupleConstructor, @@ -211,18 +209,14 @@ class OurVisitor: exp_type = self.visit_type(module, node.annotation) - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, ast.Constant): - _raise_static_error(node, 'Must be constant') - - constant = ModuleConstantDef( + if isinstance(node.value, ast.Constant): + return ModuleConstantDef( node.target.id, node.lineno, exp_type, - self.visit_Module_Constant(module, exp_type, node.value), + self.visit_Module_Constant(module, node.value), None, ) - return constant if isinstance(exp_type, TypeTuple): if not isinstance(node.value, ast.Tuple): @@ -232,7 +226,7 @@ class OurVisitor: _raise_static_error(node, 'Invalid number of tuple values') tuple_data = [ - self.visit_Module_Constant(module, mem.type, arg_node) + self.visit_Module_Constant(module, arg_node) for arg_node, mem in zip(node.value.elts, exp_type.members) if isinstance(arg_node, ast.Constant) ] @@ -260,7 +254,7 @@ class OurVisitor: _raise_static_error(node, 'Invalid number of static array values') static_array_data = [ - self.visit_Module_Constant(module, exp_type.member_type, arg_node) + self.visit_Module_Constant(module, arg_node) for arg_node in node.value.elts if isinstance(arg_node, ast.Constant) ] @@ -413,7 +407,7 @@ class OurVisitor: if isinstance(node, ast.Constant): return self.visit_Module_Constant( - module, exp_type, node, + module, node, ) if isinstance(node, ast.Attribute): @@ -649,12 +643,15 @@ class OurVisitor: ) if isinstance(node_typ, TypeTuple): - if not isinstance(slice_expr, ConstantUInt32): + if not isinstance(slice_expr, ConstantPrimitive): _raise_static_error(node, 'Must subscript using a constant index') idx = slice_expr.value - if len(node_typ.members) <= idx: + if not isinstance(idx, int): + _raise_static_error(node, 'Must subscript using a constant integer index') + + if not (0 <= idx < len(node_typ.members)): _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') tuple_member = node_typ.members[idx] @@ -673,7 +670,7 @@ class OurVisitor: if exp_type != node_typ.member_type: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') - if not isinstance(slice_expr, ConstantInt32): + if not isinstance(slice_expr, ConstantPrimitive): return AccessStaticArrayMember( varref, node_typ, @@ -682,7 +679,10 @@ class OurVisitor: idx = slice_expr.value - if len(node_typ.members) <= idx: + if not isinstance(idx, int): + _raise_static_error(node, 'Must subscript using an integer index') + + if not (0 <= idx < len(node_typ.members)): _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') static_array_member = node_typ.members[idx] @@ -695,73 +695,15 @@ class OurVisitor: _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: + def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: del module _not_implemented(node.kind is None, 'Constant.kind') - if isinstance(exp_type, TypeUInt8): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < 0 or node.value > 255: - # _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') + if isinstance(node.value, (int, float, )): + return ConstantPrimitive(node.value) - return ConstantUInt8(exp_type, node.value) - - if isinstance(exp_type, TypeUInt32): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < 0 or node.value > 4294967295: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt32(exp_type, node.value) - - if isinstance(exp_type, TypeUInt64): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < 0 or node.value > 18446744073709551615: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt64(exp_type, node.value) - - if isinstance(exp_type, TypeInt32): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < -2147483648 or node.value > 2147483647: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantInt32(exp_type, node.value) - - if isinstance(exp_type, TypeInt64): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < -9223372036854775808 or node.value > 9223372036854775807: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantInt64(exp_type, node.value) - - if isinstance(exp_type, TypeFloat32): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat32(exp_type, node.value) - - if isinstance(exp_type, TypeFloat64): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat64(exp_type, node.value) - - raise NotImplementedError(f'{node} as const for type {exp_type}') + raise NotImplementedError(f'{node.value} as constant') def visit_type(self, module: Module, node: ast.expr) -> TypeBase: if isinstance(node, ast.Constant): diff --git a/phasm/typer.py b/phasm/typer.py index 97c5e44..e835c96 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -3,29 +3,33 @@ Type checks and enriches the given ast """ from . import ourlang -from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar +from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar def phasm_type(inp: ourlang.Module) -> None: module(inp) def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': - value = getattr(inp, 'value', None) - if isinstance(value, int): + if isinstance(inp, ourlang.ConstantPrimitive): result = ctx.new_var() + if not isinstance(inp.value, int): + raise NotImplementedError('Float constants in new type system') + + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + # Need at least this many bits to store this constant value - result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2)) + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # Don't dictate anything about signedness - you can use a signed # constant in an unsigned variable if the bits fit result.add_constraint(TypeConstraintSigned(None)) - result.add_location(str(value)) + result.add_location(str(inp.value)) inp.type_var = result return result - raise NotImplementedError(constant, inp, value) + raise NotImplementedError(constant, inp) def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): diff --git a/phasm/typing.py b/phasm/typing.py index 65bb095..72a5827 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -3,6 +3,8 @@ The phasm type system """ from typing import Dict, Optional, List, Type +import enum + from .exceptions import TypingError class TypeBase: @@ -218,6 +220,30 @@ class TypeConstraintBase: def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) +class TypeConstraintPrimitive(TypeConstraintBase): + __slots__ = ('primitive', ) + + class Primitive(enum.Enum): + INT = 0 + FLOAT = 1 + + primitive: Primitive + + def __init__(self, primitive: Primitive) -> None: + self.primitive = primitive + + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': + if not isinstance(other, TypeConstraintPrimitive): + raise Exception('Invalid comparison') + + if self.primitive != other.primitive: + raise TypingNarrowProtoError('Primitive does not match') + + return TypeConstraintPrimitive(self.primitive) + + def __repr__(self) -> str: + return f'Primitive={self.primitive.name}' + class TypeConstraintSigned(TypeConstraintBase): __slots__ = ('signed', ) @@ -326,3 +352,28 @@ class Context: r.locations = newtypevar.locations return + +def simplify(inp: TypeVar) -> Optional[str]: + tc_prim = inp.constraints.get(TypeConstraintPrimitive) + tc_bits = inp.constraints.get(TypeConstraintBitWidth) + tc_sign = inp.constraints.get(TypeConstraintSigned) + + if tc_prim is None: + return None + + assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint + primitive = tc_prim.primitive + if primitive is TypeConstraintPrimitive.Primitive.INT: + if tc_bits is None or tc_sign is None: + return None + + assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint + assert isinstance(tc_sign, TypeConstraintSigned) # type hint + + if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64): + return None + + base = 'i' if tc_sign.signed else 'u' + return f'{base}{tc_bits.minb}' + + return None