First attempt at ripping out old system

This breaks test_addition[u32], which is a good thing to chase next.
This commit is contained in:
Johan B.W. de Vries 2022-09-16 17:39:46 +02:00
parent 2d0daf4b90
commit 6f3d9a5bcc
6 changed files with 109 additions and 199 deletions

View File

@ -86,14 +86,8 @@ def expression(inp: ourlang.Expression) -> str:
""" """
Render: A Phasm expression Render: A Phasm expression
""" """
if isinstance(inp, ( if isinstance(inp, ourlang.ConstantPrimitive):
ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, # Floats might not round trip if the original constant
ourlang.ConstantInt32, ourlang.ConstantInt64,
)):
return str(inp.value)
if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )):
# These might not round trip if the original constant
# could not fit in the given float type # could not fit in the given float type
return str(inp.value) return str(inp.value)

View File

@ -131,33 +131,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
""" """
Compile: Any expression Compile: Any expression
""" """
if isinstance(inp, ourlang.ConstantUInt8): if isinstance(inp, ourlang.ConstantPrimitive):
wgn.i32.const(inp.value) stp = typing.simplify(inp.type_var)
return if stp is None:
raise NotImplementedError(f'Constants with type {inp.type_var}')
if isinstance(inp, ourlang.ConstantUInt32): if stp == 'u8':
wgn.i32.const(inp.value) # No native u8 type - treat as i32, with caution
return wgn.i32.const(inp.value)
return
if isinstance(inp, ourlang.ConstantUInt64): if stp in ('i32', 'u32'):
wgn.i64.const(inp.value) wgn.i32.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantInt32): if stp in ('i64', 'u64'):
wgn.i32.const(inp.value) wgn.i64.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantInt64): raise NotImplementedError(f'Constants with type {stp}')
wgn.i64.const(inp.value)
return
if isinstance(inp, ourlang.ConstantFloat32):
wgn.f32.const(inp.value)
return
if isinstance(inp, ourlang.ConstantFloat64):
wgn.f64.const(inp.value)
return
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
wgn.add_statement('local.get', '${}'.format(inp.variable.name)) wgn.add_statement('local.get', '${}'.format(inp.variable.name))

View File

@ -44,88 +44,15 @@ class Constant(Expression):
""" """
__slots__ = () __slots__ = ()
class ConstantUInt8(Constant): class ConstantPrimitive(Constant):
""" """
An UInt8 constant value expression within a statement An primitive constant value expression within a statement
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: int value: Union[int, float]
def __init__(self, type_: TypeUInt8, value: int) -> None: def __init__(self, value: Union[int, float]) -> None:
super().__init__(type_)
self.value = value
class ConstantUInt32(Constant):
"""
An UInt32 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeUInt32, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantUInt64(Constant):
"""
An UInt64 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeUInt64, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantInt32(Constant):
"""
An Int32 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeInt32, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantInt64(Constant):
"""
An Int64 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeInt64, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantFloat32(Constant):
"""
An Float32 constant value expression within a statement
"""
__slots__ = ('value', )
value: float
def __init__(self, type_: TypeFloat32, value: float) -> None:
super().__init__(type_)
self.value = value
class ConstantFloat64(Constant):
"""
An Float64 constant value expression within a statement
"""
__slots__ = ('value', )
value: float
def __init__(self, type_: TypeFloat64, value: float) -> None:
super().__init__(type_)
self.value = value self.value = value
class ConstantTuple(Constant): class ConstantTuple(Constant):
@ -134,9 +61,9 @@ class ConstantTuple(Constant):
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: List[Constant] value: List[ConstantPrimitive]
def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
super().__init__(type_) super().__init__(type_)
self.value = value self.value = value
@ -146,9 +73,9 @@ class ConstantStaticArray(Constant):
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: List[Constant] value: List[ConstantPrimitive]
def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None: def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
super().__init__(type_) super().__init__(type_)
self.value = value self.value = value
@ -455,10 +382,10 @@ class ModuleDataBlock:
""" """
__slots__ = ('data', 'address', ) __slots__ = ('data', 'address', )
data: List[Constant] data: List[ConstantPrimitive]
address: Optional[int] address: Optional[int]
def __init__(self, data: List[Constant]) -> None: def __init__(self, data: List[ConstantPrimitive]) -> None:
self.data = data self.data = data
self.address = None self.address = None

View File

@ -35,9 +35,7 @@ from .ourlang import (
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp, BinaryOp,
Constant, Constant,
ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, ConstantPrimitive, ConstantTuple, ConstantStaticArray,
ConstantUInt8, ConstantUInt32, ConstantUInt64,
ConstantTuple, ConstantStaticArray,
FunctionCall, FunctionCall,
StructConstructor, TupleConstructor, StructConstructor, TupleConstructor,
@ -211,18 +209,14 @@ class OurVisitor:
exp_type = self.visit_type(module, node.annotation) exp_type = self.visit_type(module, node.annotation)
if isinstance(exp_type, TypeInt32): if isinstance(node.value, ast.Constant):
if not isinstance(node.value, ast.Constant): return ModuleConstantDef(
_raise_static_error(node, 'Must be constant')
constant = ModuleConstantDef(
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type, exp_type,
self.visit_Module_Constant(module, exp_type, node.value), self.visit_Module_Constant(module, node.value),
None, None,
) )
return constant
if isinstance(exp_type, TypeTuple): if isinstance(exp_type, TypeTuple):
if not isinstance(node.value, ast.Tuple): if not isinstance(node.value, ast.Tuple):
@ -232,7 +226,7 @@ class OurVisitor:
_raise_static_error(node, 'Invalid number of tuple values') _raise_static_error(node, 'Invalid number of tuple values')
tuple_data = [ tuple_data = [
self.visit_Module_Constant(module, mem.type, arg_node) self.visit_Module_Constant(module, arg_node)
for arg_node, mem in zip(node.value.elts, exp_type.members) for arg_node, mem in zip(node.value.elts, exp_type.members)
if isinstance(arg_node, ast.Constant) if isinstance(arg_node, ast.Constant)
] ]
@ -260,7 +254,7 @@ class OurVisitor:
_raise_static_error(node, 'Invalid number of static array values') _raise_static_error(node, 'Invalid number of static array values')
static_array_data = [ static_array_data = [
self.visit_Module_Constant(module, exp_type.member_type, arg_node) self.visit_Module_Constant(module, arg_node)
for arg_node in node.value.elts for arg_node in node.value.elts
if isinstance(arg_node, ast.Constant) if isinstance(arg_node, ast.Constant)
] ]
@ -413,7 +407,7 @@ class OurVisitor:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return self.visit_Module_Constant( return self.visit_Module_Constant(
module, exp_type, node, module, node,
) )
if isinstance(node, ast.Attribute): if isinstance(node, ast.Attribute):
@ -649,12 +643,15 @@ class OurVisitor:
) )
if isinstance(node_typ, TypeTuple): if isinstance(node_typ, TypeTuple):
if not isinstance(slice_expr, ConstantUInt32): if not isinstance(slice_expr, ConstantPrimitive):
_raise_static_error(node, 'Must subscript using a constant index') _raise_static_error(node, 'Must subscript using a constant index')
idx = slice_expr.value idx = slice_expr.value
if len(node_typ.members) <= idx: if not isinstance(idx, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not (0 <= idx < len(node_typ.members)):
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
tuple_member = node_typ.members[idx] tuple_member = node_typ.members[idx]
@ -673,7 +670,7 @@ class OurVisitor:
if exp_type != node_typ.member_type: if exp_type != node_typ.member_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
if not isinstance(slice_expr, ConstantInt32): if not isinstance(slice_expr, ConstantPrimitive):
return AccessStaticArrayMember( return AccessStaticArrayMember(
varref, varref,
node_typ, node_typ,
@ -682,7 +679,10 @@ class OurVisitor:
idx = slice_expr.value idx = slice_expr.value
if len(node_typ.members) <= idx: if not isinstance(idx, int):
_raise_static_error(node, 'Must subscript using an integer index')
if not (0 <= idx < len(node_typ.members)):
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
static_array_member = node_typ.members[idx] static_array_member = node_typ.members[idx]
@ -695,73 +695,15 @@ class OurVisitor:
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
del module del module
_not_implemented(node.kind is None, 'Constant.kind') _not_implemented(node.kind is None, 'Constant.kind')
if isinstance(exp_type, TypeUInt8): if isinstance(node.value, (int, float, )):
# if not isinstance(node.value, int): return ConstantPrimitive(node.value)
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < 0 or node.value > 255:
# _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
return ConstantUInt8(exp_type, node.value) raise NotImplementedError(f'{node.value} as constant')
if isinstance(exp_type, TypeUInt32):
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < 0 or node.value > 4294967295:
# _raise_static_error(node, 'Integer value out of range')
return ConstantUInt32(exp_type, node.value)
if isinstance(exp_type, TypeUInt64):
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < 0 or node.value > 18446744073709551615:
# _raise_static_error(node, 'Integer value out of range')
return ConstantUInt64(exp_type, node.value)
if isinstance(exp_type, TypeInt32):
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < -2147483648 or node.value > 2147483647:
# _raise_static_error(node, 'Integer value out of range')
return ConstantInt32(exp_type, node.value)
if isinstance(exp_type, TypeInt64):
# if not isinstance(node.value, int):
# _raise_static_error(node, 'Expected integer value')
#
# if node.value < -9223372036854775808 or node.value > 9223372036854775807:
# _raise_static_error(node, 'Integer value out of range')
return ConstantInt64(exp_type, node.value)
if isinstance(exp_type, TypeFloat32):
if not isinstance(node.value, (float, int, )):
_raise_static_error(node, 'Expected float value')
# FIXME: Range check
return ConstantFloat32(exp_type, node.value)
if isinstance(exp_type, TypeFloat64):
if not isinstance(node.value, (float, int, )):
_raise_static_error(node, 'Expected float value')
# FIXME: Range check
return ConstantFloat64(exp_type, node.value)
raise NotImplementedError(f'{node} as const for type {exp_type}')
def visit_type(self, module: Module, node: ast.expr) -> TypeBase: def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):

View File

@ -3,29 +3,33 @@ Type checks and enriches the given ast
""" """
from . import ourlang from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar
def phasm_type(inp: ourlang.Module) -> None: def phasm_type(inp: ourlang.Module) -> None:
module(inp) module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
value = getattr(inp, 'value', None) if isinstance(inp, ourlang.ConstantPrimitive):
if isinstance(value, int):
result = ctx.new_var() result = ctx.new_var()
if not isinstance(inp.value, int):
raise NotImplementedError('Float constants in new type system')
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
# Need at least this many bits to store this constant value # Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2)) result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
# Don't dictate anything about signedness - you can use a signed # Don't dictate anything about signedness - you can use a signed
# constant in an unsigned variable if the bits fit # constant in an unsigned variable if the bits fit
result.add_constraint(TypeConstraintSigned(None)) result.add_constraint(TypeConstraintSigned(None))
result.add_location(str(value)) result.add_location(str(inp.value))
inp.type_var = result inp.type_var = result
return result return result
raise NotImplementedError(constant, inp, value) raise NotImplementedError(constant, inp)
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):

View File

@ -3,6 +3,8 @@ The phasm type system
""" """
from typing import Dict, Optional, List, Type from typing import Dict, Optional, List, Type
import enum
from .exceptions import TypingError from .exceptions import TypingError
class TypeBase: class TypeBase:
@ -218,6 +220,30 @@ class TypeConstraintBase:
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other) raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
__slots__ = ('primitive', )
class Primitive(enum.Enum):
INT = 0
FLOAT = 1
primitive: Primitive
def __init__(self, primitive: Primitive) -> None:
self.primitive = primitive
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
if not isinstance(other, TypeConstraintPrimitive):
raise Exception('Invalid comparison')
if self.primitive != other.primitive:
raise TypingNarrowProtoError('Primitive does not match')
return TypeConstraintPrimitive(self.primitive)
def __repr__(self) -> str:
return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase): class TypeConstraintSigned(TypeConstraintBase):
__slots__ = ('signed', ) __slots__ = ('signed', )
@ -326,3 +352,28 @@ class Context:
r.locations = newtypevar.locations r.locations = newtypevar.locations
return return
def simplify(inp: TypeVar) -> Optional[str]:
tc_prim = inp.constraints.get(TypeConstraintPrimitive)
tc_bits = inp.constraints.get(TypeConstraintBitWidth)
tc_sign = inp.constraints.get(TypeConstraintSigned)
if tc_prim is None:
return None
assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint
primitive = tc_prim.primitive
if primitive is TypeConstraintPrimitive.Primitive.INT:
if tc_bits is None or tc_sign is None:
return None
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
assert isinstance(tc_sign, TypeConstraintSigned) # type hint
if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64):
return None
base = 'i' if tc_sign.signed else 'u'
return f'{base}{tc_bits.minb}'
return None