Compare commits

...

5 Commits

Author SHA1 Message Date
Johan B.W. de Vries
6f3d9a5bcc First attempt at ripping out old system
This breaks test_addition[u32], which is a good thing to chase next.
2022-09-16 17:39:46 +02:00
Johan B.W. de Vries
2d0daf4b90 Fixes 2022-09-16 17:04:13 +02:00
Johan B.W. de Vries
7669f3cbca More framework stuff 2022-09-16 17:01:23 +02:00
Johan B.W. de Vries
48e16c38b9 FunctionParam is a class, more framework stuff 2022-09-16 16:43:40 +02:00
Johan B.W. de Vries
7acb2bd8e6 Framework sketch 2022-09-16 15:54:24 +02:00
10 changed files with 451 additions and 226 deletions

View File

@ -86,14 +86,8 @@ def expression(inp: ourlang.Expression) -> str:
""" """
Render: A Phasm expression Render: A Phasm expression
""" """
if isinstance(inp, ( if isinstance(inp, ourlang.ConstantPrimitive):
ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, # Floats might not round trip if the original constant
ourlang.ConstantInt32, ourlang.ConstantInt64,
)):
return str(inp.value)
if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )):
# These might not round trip if the original constant
# could not fit in the given float type # could not fit in the given float type
return str(inp.value) return str(inp.value)
@ -104,7 +98,7 @@ def expression(inp: ourlang.Expression) -> str:
) + ', )' ) + ', )'
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
return str(inp.name) return str(inp.variable.name)
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
if ( if (
@ -193,8 +187,8 @@ def function(inp: ourlang.Function) -> str:
result += '@imported\n' result += '@imported\n'
args = ', '.join( args = ', '.join(
f'{x}: {type_(y)}' f'{p.name}: {type_(p.type)}'
for x, y in inp.posonlyargs for p in inp.posonlyargs
) )
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'

View File

@ -131,36 +131,28 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
""" """
Compile: Any expression Compile: Any expression
""" """
if isinstance(inp, ourlang.ConstantUInt8): if isinstance(inp, ourlang.ConstantPrimitive):
wgn.i32.const(inp.value) stp = typing.simplify(inp.type_var)
return if stp is None:
raise NotImplementedError(f'Constants with type {inp.type_var}')
if isinstance(inp, ourlang.ConstantUInt32): if stp == 'u8':
wgn.i32.const(inp.value) # No native u8 type - treat as i32, with caution
return wgn.i32.const(inp.value)
return
if isinstance(inp, ourlang.ConstantUInt64): if stp in ('i32', 'u32'):
wgn.i64.const(inp.value) wgn.i32.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantInt32): if stp in ('i64', 'u64'):
wgn.i32.const(inp.value) wgn.i64.const(inp.value)
return return
if isinstance(inp, ourlang.ConstantInt64): raise NotImplementedError(f'Constants with type {stp}')
wgn.i64.const(inp.value)
return
if isinstance(inp, ourlang.ConstantFloat32):
wgn.f32.const(inp.value)
return
if isinstance(inp, ourlang.ConstantFloat64):
wgn.f64.const(inp.value)
return
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
wgn.add_statement('local.get', '${}'.format(inp.name)) wgn.add_statement('local.get', '${}'.format(inp.variable.name))
return return
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
@ -450,7 +442,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
""" """
Compile: function argument Compile: function argument
""" """
return (inp[0], type_(inp[1]), ) return (inp.name, type_(inp.type), )
def import_(inp: ourlang.Function) -> wasm.Import: def import_(inp: ourlang.Function) -> wasm.Import:
""" """

View File

@ -6,3 +6,6 @@ class StaticError(Exception):
""" """
An error found during static analysis An error found during static analysis
""" """
class TypingError(Exception):
pass

View File

@ -21,18 +21,22 @@ from .typing import (
TypeTuple, TypeTupleMember, TypeTuple, TypeTupleMember,
TypeStaticArray, TypeStaticArrayMember, TypeStaticArray, TypeStaticArrayMember,
TypeStruct, TypeStructMember, TypeStruct, TypeStructMember,
TypeVar,
) )
class Expression: class Expression:
""" """
An expression within a statement An expression within a statement
""" """
__slots__ = ('type', ) __slots__ = ('type', 'type_var', )
type: TypeBase type: TypeBase
type_var: Optional[TypeVar]
def __init__(self, type_: TypeBase) -> None: def __init__(self, type_: TypeBase) -> None:
self.type = type_ self.type = type_
self.type_var = None
class Constant(Expression): class Constant(Expression):
""" """
@ -40,88 +44,15 @@ class Constant(Expression):
""" """
__slots__ = () __slots__ = ()
class ConstantUInt8(Constant): class ConstantPrimitive(Constant):
""" """
An UInt8 constant value expression within a statement An primitive constant value expression within a statement
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: int value: Union[int, float]
def __init__(self, type_: TypeUInt8, value: int) -> None: def __init__(self, value: Union[int, float]) -> None:
super().__init__(type_)
self.value = value
class ConstantUInt32(Constant):
"""
An UInt32 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeUInt32, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantUInt64(Constant):
"""
An UInt64 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeUInt64, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantInt32(Constant):
"""
An Int32 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeInt32, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantInt64(Constant):
"""
An Int64 constant value expression within a statement
"""
__slots__ = ('value', )
value: int
def __init__(self, type_: TypeInt64, value: int) -> None:
super().__init__(type_)
self.value = value
class ConstantFloat32(Constant):
"""
An Float32 constant value expression within a statement
"""
__slots__ = ('value', )
value: float
def __init__(self, type_: TypeFloat32, value: float) -> None:
super().__init__(type_)
self.value = value
class ConstantFloat64(Constant):
"""
An Float64 constant value expression within a statement
"""
__slots__ = ('value', )
value: float
def __init__(self, type_: TypeFloat64, value: float) -> None:
super().__init__(type_)
self.value = value self.value = value
class ConstantTuple(Constant): class ConstantTuple(Constant):
@ -130,9 +61,9 @@ class ConstantTuple(Constant):
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: List[Constant] value: List[ConstantPrimitive]
def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
super().__init__(type_) super().__init__(type_)
self.value = value self.value = value
@ -142,9 +73,9 @@ class ConstantStaticArray(Constant):
""" """
__slots__ = ('value', ) __slots__ = ('value', )
value: List[Constant] value: List[ConstantPrimitive]
def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None: def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
super().__init__(type_) super().__init__(type_)
self.value = value self.value = value
@ -152,13 +83,13 @@ class VariableReference(Expression):
""" """
An variable reference expression within a statement An variable reference expression within a statement
""" """
__slots__ = ('name', ) __slots__ = ('variable', )
name: str variable: 'FunctionParam' # also possibly local
def __init__(self, type_: TypeBase, name: str) -> None: def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None:
super().__init__(type_) super().__init__(type_)
self.name = name self.variable = variable
class UnaryOp(Expression): class UnaryOp(Expression):
""" """
@ -348,13 +279,23 @@ class StatementIf(Statement):
self.statements = [] self.statements = []
self.else_statements = [] self.else_statements = []
FunctionParam = Tuple[str, TypeBase] class FunctionParam:
__slots__ = ('name', 'type', 'type_var', )
name: str
type: TypeBase
type_var: Optional[TypeVar]
def __init__(self, name: str, type_: TypeBase) -> None:
self.name = name
self.type = type_
self.type_var = None
class Function: class Function:
""" """
A function processes input and produces output A function processes input and produces output
""" """
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', )
name: str name: str
lineno: int lineno: int
@ -362,6 +303,7 @@ class Function:
imported: bool imported: bool
statements: List[Statement] statements: List[Statement]
returns: TypeBase returns: TypeBase
returns_type_var: Optional[TypeVar]
posonlyargs: List[FunctionParam] posonlyargs: List[FunctionParam]
def __init__(self, name: str, lineno: int) -> None: def __init__(self, name: str, lineno: int) -> None:
@ -371,6 +313,7 @@ class Function:
self.imported = False self.imported = False
self.statements = [] self.statements = []
self.returns = TypeNone() self.returns = TypeNone()
self.returns_type_var = None
self.posonlyargs = [] self.posonlyargs = []
class StructConstructor(Function): class StructConstructor(Function):
@ -390,7 +333,7 @@ class StructConstructor(Function):
self.returns = struct self.returns = struct
for mem in struct.members: for mem in struct.members:
self.posonlyargs.append((mem.name, mem.type, )) self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
self.struct = struct self.struct = struct
@ -410,7 +353,7 @@ class TupleConstructor(Function):
self.returns = tuple_ self.returns = tuple_
for mem in tuple_.members: for mem in tuple_.members:
self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
self.tuple = tuple_ self.tuple = tuple_
@ -439,10 +382,10 @@ class ModuleDataBlock:
""" """
__slots__ = ('data', 'address', ) __slots__ = ('data', 'address', )
data: List[Constant] data: List[ConstantPrimitive]
address: Optional[int] address: Optional[int]
def __init__(self, data: List[Constant]) -> None: def __init__(self, data: List[ConstantPrimitive]) -> None:
self.data = data self.data = data
self.address = None self.address = None

View File

@ -35,9 +35,7 @@ from .ourlang import (
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp, BinaryOp,
Constant, Constant,
ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, ConstantPrimitive, ConstantTuple, ConstantStaticArray,
ConstantUInt8, ConstantUInt32, ConstantUInt64,
ConstantTuple, ConstantStaticArray,
FunctionCall, FunctionCall,
StructConstructor, TupleConstructor, StructConstructor, TupleConstructor,
@ -48,6 +46,7 @@ from .ourlang import (
Statement, Statement,
StatementIf, StatementPass, StatementReturn, StatementIf, StatementPass, StatementReturn,
FunctionParam,
ModuleConstantDef, ModuleConstantDef,
) )
@ -60,7 +59,7 @@ def phasm_parse(source: str) -> Module:
our_visitor = OurVisitor() our_visitor = OurVisitor()
return our_visitor.visit_Module(res) return our_visitor.visit_Module(res)
OurLocals = Dict[str, TypeBase] OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants?
class OurVisitor: class OurVisitor:
""" """
@ -141,7 +140,7 @@ class OurVisitor:
if not arg.annotation: if not arg.annotation:
_raise_static_error(node, 'Type is required') _raise_static_error(node, 'Type is required')
function.posonlyargs.append(( function.posonlyargs.append(FunctionParam(
arg.arg, arg.arg,
self.visit_type(module, arg.annotation), self.visit_type(module, arg.annotation),
)) ))
@ -210,18 +209,14 @@ class OurVisitor:
exp_type = self.visit_type(module, node.annotation) exp_type = self.visit_type(module, node.annotation)
if isinstance(exp_type, TypeInt32): if isinstance(node.value, ast.Constant):
if not isinstance(node.value, ast.Constant): return ModuleConstantDef(
_raise_static_error(node, 'Must be constant')
constant = ModuleConstantDef(
node.target.id, node.target.id,
node.lineno, node.lineno,
exp_type, exp_type,
self.visit_Module_Constant(module, exp_type, node.value), self.visit_Module_Constant(module, node.value),
None, None,
) )
return constant
if isinstance(exp_type, TypeTuple): if isinstance(exp_type, TypeTuple):
if not isinstance(node.value, ast.Tuple): if not isinstance(node.value, ast.Tuple):
@ -231,7 +226,7 @@ class OurVisitor:
_raise_static_error(node, 'Invalid number of tuple values') _raise_static_error(node, 'Invalid number of tuple values')
tuple_data = [ tuple_data = [
self.visit_Module_Constant(module, mem.type, arg_node) self.visit_Module_Constant(module, arg_node)
for arg_node, mem in zip(node.value.elts, exp_type.members) for arg_node, mem in zip(node.value.elts, exp_type.members)
if isinstance(arg_node, ast.Constant) if isinstance(arg_node, ast.Constant)
] ]
@ -259,7 +254,7 @@ class OurVisitor:
_raise_static_error(node, 'Invalid number of static array values') _raise_static_error(node, 'Invalid number of static array values')
static_array_data = [ static_array_data = [
self.visit_Module_Constant(module, exp_type.member_type, arg_node) self.visit_Module_Constant(module, arg_node)
for arg_node in node.value.elts for arg_node in node.value.elts
if isinstance(arg_node, ast.Constant) if isinstance(arg_node, ast.Constant)
] ]
@ -297,7 +292,10 @@ class OurVisitor:
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
function = module.functions[node.name] function = module.functions[node.name]
our_locals = dict(function.posonlyargs) our_locals: OurLocals = {
x.name: x
for x in function.posonlyargs
}
for stmt in node.body: for stmt in node.body:
function.statements.append( function.statements.append(
@ -409,7 +407,7 @@ class OurVisitor:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return self.visit_Module_Constant( return self.visit_Module_Constant(
module, exp_type, node, module, node,
) )
if isinstance(node, ast.Attribute): if isinstance(node, ast.Attribute):
@ -427,11 +425,11 @@ class OurVisitor:
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.id in our_locals: if node.id in our_locals:
act_type = our_locals[node.id] param = our_locals[node.id]
if exp_type != act_type: if exp_type != param.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}')
return VariableReference(act_type, node.id) return VariableReference(param.type, param)
if node.id in module.constant_defs: if node.id in module.constant_defs:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
@ -541,10 +539,10 @@ class OurVisitor:
if exp_type.__class__ != func.returns.__class__: if exp_type.__class__ != func.returns.__class__:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if func.returns.__class__ != func.posonlyargs[0][1].__class__: if func.returns.__class__ != func.posonlyargs[0].type.__class__:
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')
if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__: if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold( return Fold(
@ -560,16 +558,16 @@ class OurVisitor:
func = module.functions[node.func.id] func = module.functions[node.func.id]
if func.returns != exp_type: # if func.returns != exp_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') # _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if len(func.posonlyargs) != len(node.args): if len(func.posonlyargs) != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
result = FunctionCall(func) result = FunctionCall(func)
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr)
for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) for arg_expr, param in zip(node.args, func.posonlyargs)
) )
return result return result
@ -586,7 +584,9 @@ class OurVisitor:
if not node.value.id in our_locals: if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}') _raise_static_error(node, f'Undefined variable {node.value.id}')
node_typ = our_locals[node.value.id] param = our_locals[node.value.id]
node_typ = param.type
if not isinstance(node_typ, TypeStruct): if not isinstance(node_typ, TypeStruct):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
@ -598,7 +598,7 @@ class OurVisitor:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
return AccessStructMember( return AccessStructMember(
VariableReference(node_typ, node.value.id), VariableReference(node_typ, param),
member, member,
) )
@ -614,8 +614,9 @@ class OurVisitor:
varref: Union[ModuleConstantReference, VariableReference] varref: Union[ModuleConstantReference, VariableReference]
if node.value.id in our_locals: if node.value.id in our_locals:
node_typ = our_locals[node.value.id] param = our_locals[node.value.id]
varref = VariableReference(node_typ, node.value.id) node_typ = param.type
varref = VariableReference(param.type, param)
elif node.value.id in module.constant_defs: elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id] constant_def = module.constant_defs[node.value.id]
node_typ = constant_def.type node_typ = constant_def.type
@ -642,12 +643,15 @@ class OurVisitor:
) )
if isinstance(node_typ, TypeTuple): if isinstance(node_typ, TypeTuple):
if not isinstance(slice_expr, ConstantUInt32): if not isinstance(slice_expr, ConstantPrimitive):
_raise_static_error(node, 'Must subscript using a constant index') _raise_static_error(node, 'Must subscript using a constant index')
idx = slice_expr.value idx = slice_expr.value
if len(node_typ.members) <= idx: if not isinstance(idx, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not (0 <= idx < len(node_typ.members)):
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
tuple_member = node_typ.members[idx] tuple_member = node_typ.members[idx]
@ -666,7 +670,7 @@ class OurVisitor:
if exp_type != node_typ.member_type: if exp_type != node_typ.member_type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
if not isinstance(slice_expr, ConstantInt32): if not isinstance(slice_expr, ConstantPrimitive):
return AccessStaticArrayMember( return AccessStaticArrayMember(
varref, varref,
node_typ, node_typ,
@ -675,7 +679,10 @@ class OurVisitor:
idx = slice_expr.value idx = slice_expr.value
if len(node_typ.members) <= idx: if not isinstance(idx, int):
_raise_static_error(node, 'Must subscript using an integer index')
if not (0 <= idx < len(node_typ.members)):
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
static_array_member = node_typ.members[idx] static_array_member = node_typ.members[idx]
@ -688,73 +695,15 @@ class OurVisitor:
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
del module del module
_not_implemented(node.kind is None, 'Constant.kind') _not_implemented(node.kind is None, 'Constant.kind')
if isinstance(exp_type, TypeUInt8): if isinstance(node.value, (int, float, )):
if not isinstance(node.value, int): return ConstantPrimitive(node.value)
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 255: raise NotImplementedError(f'{node.value} as constant')
_raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
return ConstantUInt8(exp_type, node.value)
if isinstance(exp_type, TypeUInt32):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 4294967295:
_raise_static_error(node, 'Integer value out of range')
return ConstantUInt32(exp_type, node.value)
if isinstance(exp_type, TypeUInt64):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < 0 or node.value > 18446744073709551615:
_raise_static_error(node, 'Integer value out of range')
return ConstantUInt64(exp_type, node.value)
if isinstance(exp_type, TypeInt32):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < -2147483648 or node.value > 2147483647:
_raise_static_error(node, 'Integer value out of range')
return ConstantInt32(exp_type, node.value)
if isinstance(exp_type, TypeInt64):
if not isinstance(node.value, int):
_raise_static_error(node, 'Expected integer value')
if node.value < -9223372036854775808 or node.value > 9223372036854775807:
_raise_static_error(node, 'Integer value out of range')
return ConstantInt64(exp_type, node.value)
if isinstance(exp_type, TypeFloat32):
if not isinstance(node.value, (float, int, )):
_raise_static_error(node, 'Expected float value')
# FIXME: Range check
return ConstantFloat32(exp_type, node.value)
if isinstance(exp_type, TypeFloat64):
if not isinstance(node.value, (float, int, )):
_raise_static_error(node, 'Expected float value')
# FIXME: Range check
return ConstantFloat64(exp_type, node.value)
raise NotImplementedError(f'{node} as const for type {exp_type}')
def visit_type(self, module: Module, node: ast.expr) -> TypeBase: def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):

115
phasm/typer.py Normal file
View File

@ -0,0 +1,115 @@
"""
Type checks and enriches the given ast
"""
from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar
def phasm_type(inp: ourlang.Module) -> None:
module(inp)
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
if isinstance(inp, ourlang.ConstantPrimitive):
result = ctx.new_var()
if not isinstance(inp.value, int):
raise NotImplementedError('Float constants in new type system')
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
# Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
# Don't dictate anything about signedness - you can use a signed
# constant in an unsigned variable if the bits fit
result.add_constraint(TypeConstraintSigned(None))
result.add_location(str(inp.value))
inp.type_var = result
return result
raise NotImplementedError(constant, inp)
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant):
return constant(ctx, inp)
if isinstance(inp, ourlang.VariableReference):
assert inp.variable.type_var is not None, inp
return inp.variable.type_var
if isinstance(inp, ourlang.BinaryOp):
if inp.operator not in ('+', '-', '|', '&', '^'):
raise NotImplementedError(expression, inp, inp.operator)
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
return left
if isinstance(inp, ourlang.FunctionCall):
assert inp.function.returns_type_var is not None
if inp.function.posonlyargs:
raise NotImplementedError
return inp.function.returns_type_var
raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None:
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement')
typ = expression(ctx, inp.statements[0].value)
assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ)
return
def module(inp: ourlang.Module) -> None:
ctx = Context()
for func in inp.functions.values():
func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)')
for param in func.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
for func in inp.functions.values():
function(ctx, func)
from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar:
result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location(location)
return result
raise NotImplementedError(_convert_old_type, inp)

View File

@ -1,7 +1,11 @@
""" """
The phasm type system The phasm type system
""" """
from typing import Optional, List from typing import Dict, Optional, List, Type
import enum
from .exceptions import TypingError
class TypeBase: class TypeBase:
""" """
@ -200,3 +204,176 @@ class TypeStruct(TypeBase):
x.type.alloc_size() x.type.alloc_size()
for x in self.members for x in self.members
) )
## NEW STUFF BELOW
class TypingNarrowProtoError(TypingError):
pass
class TypingNarrowError(TypingError):
def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
super().__init__(
f'Cannot narrow types {l} and {r}: {msg}'
)
class TypeConstraintBase:
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
__slots__ = ('primitive', )
class Primitive(enum.Enum):
INT = 0
FLOAT = 1
primitive: Primitive
def __init__(self, primitive: Primitive) -> None:
self.primitive = primitive
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
if not isinstance(other, TypeConstraintPrimitive):
raise Exception('Invalid comparison')
if self.primitive != other.primitive:
raise TypingNarrowProtoError('Primitive does not match')
return TypeConstraintPrimitive(self.primitive)
def __repr__(self) -> str:
return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase):
__slots__ = ('signed', )
signed: Optional[bool]
def __init__(self, signed: Optional[bool]) -> None:
self.signed = signed
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
if not isinstance(other, TypeConstraintSigned):
raise Exception('Invalid comparison')
if other.signed is None:
return TypeConstraintSigned(self.signed)
if self.signed is None:
return TypeConstraintSigned(other.signed)
if self.signed is not other.signed:
raise TypingNarrowProtoError('Signed does not match')
return TypeConstraintSigned(self.signed)
def __repr__(self) -> str:
return f'Signed={self.signed}'
class TypeConstraintBitWidth(TypeConstraintBase):
__slots__ = ('minb', 'maxb', )
minb: int
maxb: int
def __init__(self, *, minb: int = 1, maxb: int = 64) -> None:
assert minb is not None or maxb is not None
assert maxb <= 64 # For now, support up to 64 bits values
self.minb = minb
self.maxb = maxb
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
if not isinstance(other, TypeConstraintBitWidth):
raise Exception('Invalid comparison')
if self.minb > other.maxb:
raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth')
if other.minb > self.maxb:
raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth')
return TypeConstraintBitWidth(
minb=max(self.minb, other.minb),
maxb=min(self.maxb, other.maxb),
)
def __repr__(self) -> str:
return f'BitWidth={self.minb}..{self.maxb}'
class TypeVar:
def __init__(self, ctx: 'Context') -> None:
self.context = ctx
self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {}
self.locations: List[str] = []
def add_constraint(self, newconst: TypeConstraintBase) -> None:
if newconst.__class__ in self.constraints:
self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst)
else:
self.constraints[newconst.__class__] = newconst
def add_location(self, ref: str) -> None:
self.locations.append(ref)
def __repr__(self) -> str:
return (
'TypeVar<'
+ '; '.join(map(repr, self.constraints.values()))
+ '; locations: '
+ ', '.join(self.locations)
+ '>'
)
class Context:
def new_var(self) -> TypeVar:
return TypeVar(self)
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
newtypevar = self.new_var()
try:
for const in l.constraints.values():
newtypevar.add_constraint(const)
for const in r.constraints.values():
newtypevar.add_constraint(const)
except TypingNarrowProtoError as ex:
raise TypingNarrowError(l, r, str(ex)) from None
newtypevar.locations.extend(l.locations)
newtypevar.locations.extend(r.locations)
# Make pointer locations to the constraints and locations
# so they get linked together throughout the unification
l.constraints = newtypevar.constraints
l.locations = newtypevar.locations
r.constraints = newtypevar.constraints
r.locations = newtypevar.locations
return
def simplify(inp: TypeVar) -> Optional[str]:
tc_prim = inp.constraints.get(TypeConstraintPrimitive)
tc_bits = inp.constraints.get(TypeConstraintBitWidth)
tc_sign = inp.constraints.get(TypeConstraintSigned)
if tc_prim is None:
return None
assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint
primitive = tc_prim.primitive
if primitive is TypeConstraintPrimitive.Primitive.INT:
if tc_bits is None or tc_sign is None:
return None
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
assert isinstance(tc_sign, TypeConstraintSigned) # type hint
if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64):
return None
base = 'i' if tc_sign.signed else 'u'
return f'{base}{tc_bits.minb}'
return None

View File

@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO
import ctypes import ctypes
import io import io
import warnings
import pywasm.binary import pywasm.binary
import wasm3 import wasm3
@ -13,6 +14,7 @@ import wasmtime
from phasm.compiler import phasm_compile from phasm.compiler import phasm_compile
from phasm.parser import phasm_parse from phasm.parser import phasm_parse
from phasm.typer import phasm_type
from phasm import ourlang from phasm import ourlang
from phasm import wasm from phasm import wasm
@ -40,6 +42,10 @@ class RunnerBase:
Parses the Phasm code into an AST Parses the Phasm code into an AST
""" """
self.phasm_ast = phasm_parse(self.phasm_code) self.phasm_ast = phasm_parse(self.phasm_code)
try:
phasm_type(self.phasm_ast)
except NotImplementedError as exc:
warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}')
def compile_ast(self) -> None: def compile_ast(self) -> None:
""" """

View File

@ -304,6 +304,21 @@ def testEntry(a: i32, b: i32) -> i32:
assert 1 == suite.run_code(10, 20).returned_value assert 1 == suite.run_code(10, 20).returned_value
assert 0 == suite.run_code(10, 10).returned_value assert 0 == suite.run_code(10, 10).returned_value
@pytest.mark.integration_test
def test_call_no_args():
code_py = """
def helper() -> i32:
return 19
@exported
def testEntry() -> i32:
return helper()
"""
result = Suite(code_py).run_code()
assert 19 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
def test_call_pre_defined(): def test_call_pre_defined():
code_py = """ code_py = """

View File

@ -0,0 +1,31 @@
import pytest
from phasm.parser import phasm_parse
from phasm.typer import phasm_type
from phasm.exceptions import TypingError
@pytest.mark.integration_test
def test_constant_too_wide():
code_py = """
def func_const() -> u8:
return 0xFFF
"""
ast = phasm_parse(code_py)
with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'):
phasm_type(ast)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', [32, 64])
def test_signed_mismatch(type_):
code_py = f"""
def func_const() -> u{type_}:
return 0
def func_call() -> i{type_}:
return func_const()
"""
ast = phasm_parse(code_py)
with pytest.raises(TypingError, match='Signed does not match'):
phasm_type(ast)