Compare commits
5 Commits
master
...
6f3d9a5bcc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f3d9a5bcc | ||
|
|
2d0daf4b90 | ||
|
|
7669f3cbca | ||
|
|
48e16c38b9 | ||
|
|
7acb2bd8e6 |
@ -86,14 +86,8 @@ def expression(inp: ourlang.Expression) -> str:
|
|||||||
"""
|
"""
|
||||||
Render: A Phasm expression
|
Render: A Phasm expression
|
||||||
"""
|
"""
|
||||||
if isinstance(inp, (
|
if isinstance(inp, ourlang.ConstantPrimitive):
|
||||||
ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64,
|
# Floats might not round trip if the original constant
|
||||||
ourlang.ConstantInt32, ourlang.ConstantInt64,
|
|
||||||
)):
|
|
||||||
return str(inp.value)
|
|
||||||
|
|
||||||
if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )):
|
|
||||||
# These might not round trip if the original constant
|
|
||||||
# could not fit in the given float type
|
# could not fit in the given float type
|
||||||
return str(inp.value)
|
return str(inp.value)
|
||||||
|
|
||||||
@ -104,7 +98,7 @@ def expression(inp: ourlang.Expression) -> str:
|
|||||||
) + ', )'
|
) + ', )'
|
||||||
|
|
||||||
if isinstance(inp, ourlang.VariableReference):
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
return str(inp.name)
|
return str(inp.variable.name)
|
||||||
|
|
||||||
if isinstance(inp, ourlang.UnaryOp):
|
if isinstance(inp, ourlang.UnaryOp):
|
||||||
if (
|
if (
|
||||||
@ -193,8 +187,8 @@ def function(inp: ourlang.Function) -> str:
|
|||||||
result += '@imported\n'
|
result += '@imported\n'
|
||||||
|
|
||||||
args = ', '.join(
|
args = ', '.join(
|
||||||
f'{x}: {type_(y)}'
|
f'{p.name}: {type_(p.type)}'
|
||||||
for x, y in inp.posonlyargs
|
for p in inp.posonlyargs
|
||||||
)
|
)
|
||||||
|
|
||||||
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'
|
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'
|
||||||
|
|||||||
@ -131,36 +131,28 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
|||||||
"""
|
"""
|
||||||
Compile: Any expression
|
Compile: Any expression
|
||||||
"""
|
"""
|
||||||
if isinstance(inp, ourlang.ConstantUInt8):
|
if isinstance(inp, ourlang.ConstantPrimitive):
|
||||||
wgn.i32.const(inp.value)
|
stp = typing.simplify(inp.type_var)
|
||||||
return
|
if stp is None:
|
||||||
|
raise NotImplementedError(f'Constants with type {inp.type_var}')
|
||||||
|
|
||||||
if isinstance(inp, ourlang.ConstantUInt32):
|
if stp == 'u8':
|
||||||
wgn.i32.const(inp.value)
|
# No native u8 type - treat as i32, with caution
|
||||||
return
|
wgn.i32.const(inp.value)
|
||||||
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.ConstantUInt64):
|
if stp in ('i32', 'u32'):
|
||||||
wgn.i64.const(inp.value)
|
wgn.i32.const(inp.value)
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.ConstantInt32):
|
if stp in ('i64', 'u64'):
|
||||||
wgn.i32.const(inp.value)
|
wgn.i64.const(inp.value)
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.ConstantInt64):
|
raise NotImplementedError(f'Constants with type {stp}')
|
||||||
wgn.i64.const(inp.value)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(inp, ourlang.ConstantFloat32):
|
|
||||||
wgn.f32.const(inp.value)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(inp, ourlang.ConstantFloat64):
|
|
||||||
wgn.f64.const(inp.value)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(inp, ourlang.VariableReference):
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
wgn.add_statement('local.get', '${}'.format(inp.name))
|
wgn.add_statement('local.get', '${}'.format(inp.variable.name))
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.BinaryOp):
|
if isinstance(inp, ourlang.BinaryOp):
|
||||||
@ -450,7 +442,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
|
|||||||
"""
|
"""
|
||||||
Compile: function argument
|
Compile: function argument
|
||||||
"""
|
"""
|
||||||
return (inp[0], type_(inp[1]), )
|
return (inp.name, type_(inp.type), )
|
||||||
|
|
||||||
def import_(inp: ourlang.Function) -> wasm.Import:
|
def import_(inp: ourlang.Function) -> wasm.Import:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -6,3 +6,6 @@ class StaticError(Exception):
|
|||||||
"""
|
"""
|
||||||
An error found during static analysis
|
An error found during static analysis
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class TypingError(Exception):
|
||||||
|
pass
|
||||||
|
|||||||
127
phasm/ourlang.py
127
phasm/ourlang.py
@ -21,18 +21,22 @@ from .typing import (
|
|||||||
TypeTuple, TypeTupleMember,
|
TypeTuple, TypeTupleMember,
|
||||||
TypeStaticArray, TypeStaticArrayMember,
|
TypeStaticArray, TypeStaticArrayMember,
|
||||||
TypeStruct, TypeStructMember,
|
TypeStruct, TypeStructMember,
|
||||||
|
|
||||||
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
class Expression:
|
class Expression:
|
||||||
"""
|
"""
|
||||||
An expression within a statement
|
An expression within a statement
|
||||||
"""
|
"""
|
||||||
__slots__ = ('type', )
|
__slots__ = ('type', 'type_var', )
|
||||||
|
|
||||||
type: TypeBase
|
type: TypeBase
|
||||||
|
type_var: Optional[TypeVar]
|
||||||
|
|
||||||
def __init__(self, type_: TypeBase) -> None:
|
def __init__(self, type_: TypeBase) -> None:
|
||||||
self.type = type_
|
self.type = type_
|
||||||
|
self.type_var = None
|
||||||
|
|
||||||
class Constant(Expression):
|
class Constant(Expression):
|
||||||
"""
|
"""
|
||||||
@ -40,88 +44,15 @@ class Constant(Expression):
|
|||||||
"""
|
"""
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
class ConstantUInt8(Constant):
|
class ConstantPrimitive(Constant):
|
||||||
"""
|
"""
|
||||||
An UInt8 constant value expression within a statement
|
An primitive constant value expression within a statement
|
||||||
"""
|
"""
|
||||||
__slots__ = ('value', )
|
__slots__ = ('value', )
|
||||||
|
|
||||||
value: int
|
value: Union[int, float]
|
||||||
|
|
||||||
def __init__(self, type_: TypeUInt8, value: int) -> None:
|
def __init__(self, value: Union[int, float]) -> None:
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
class ConstantUInt32(Constant):
|
|
||||||
"""
|
|
||||||
An UInt32 constant value expression within a statement
|
|
||||||
"""
|
|
||||||
__slots__ = ('value', )
|
|
||||||
|
|
||||||
value: int
|
|
||||||
|
|
||||||
def __init__(self, type_: TypeUInt32, value: int) -> None:
|
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
class ConstantUInt64(Constant):
|
|
||||||
"""
|
|
||||||
An UInt64 constant value expression within a statement
|
|
||||||
"""
|
|
||||||
__slots__ = ('value', )
|
|
||||||
|
|
||||||
value: int
|
|
||||||
|
|
||||||
def __init__(self, type_: TypeUInt64, value: int) -> None:
|
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
class ConstantInt32(Constant):
|
|
||||||
"""
|
|
||||||
An Int32 constant value expression within a statement
|
|
||||||
"""
|
|
||||||
__slots__ = ('value', )
|
|
||||||
|
|
||||||
value: int
|
|
||||||
|
|
||||||
def __init__(self, type_: TypeInt32, value: int) -> None:
|
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
class ConstantInt64(Constant):
|
|
||||||
"""
|
|
||||||
An Int64 constant value expression within a statement
|
|
||||||
"""
|
|
||||||
__slots__ = ('value', )
|
|
||||||
|
|
||||||
value: int
|
|
||||||
|
|
||||||
def __init__(self, type_: TypeInt64, value: int) -> None:
|
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
class ConstantFloat32(Constant):
|
|
||||||
"""
|
|
||||||
An Float32 constant value expression within a statement
|
|
||||||
"""
|
|
||||||
__slots__ = ('value', )
|
|
||||||
|
|
||||||
value: float
|
|
||||||
|
|
||||||
def __init__(self, type_: TypeFloat32, value: float) -> None:
|
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
class ConstantFloat64(Constant):
|
|
||||||
"""
|
|
||||||
An Float64 constant value expression within a statement
|
|
||||||
"""
|
|
||||||
__slots__ = ('value', )
|
|
||||||
|
|
||||||
value: float
|
|
||||||
|
|
||||||
def __init__(self, type_: TypeFloat64, value: float) -> None:
|
|
||||||
super().__init__(type_)
|
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
class ConstantTuple(Constant):
|
class ConstantTuple(Constant):
|
||||||
@ -130,9 +61,9 @@ class ConstantTuple(Constant):
|
|||||||
"""
|
"""
|
||||||
__slots__ = ('value', )
|
__slots__ = ('value', )
|
||||||
|
|
||||||
value: List[Constant]
|
value: List[ConstantPrimitive]
|
||||||
|
|
||||||
def __init__(self, type_: TypeTuple, value: List[Constant]) -> None:
|
def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples?
|
||||||
super().__init__(type_)
|
super().__init__(type_)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@ -142,9 +73,9 @@ class ConstantStaticArray(Constant):
|
|||||||
"""
|
"""
|
||||||
__slots__ = ('value', )
|
__slots__ = ('value', )
|
||||||
|
|
||||||
value: List[Constant]
|
value: List[ConstantPrimitive]
|
||||||
|
|
||||||
def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None:
|
def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays?
|
||||||
super().__init__(type_)
|
super().__init__(type_)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@ -152,13 +83,13 @@ class VariableReference(Expression):
|
|||||||
"""
|
"""
|
||||||
An variable reference expression within a statement
|
An variable reference expression within a statement
|
||||||
"""
|
"""
|
||||||
__slots__ = ('name', )
|
__slots__ = ('variable', )
|
||||||
|
|
||||||
name: str
|
variable: 'FunctionParam' # also possibly local
|
||||||
|
|
||||||
def __init__(self, type_: TypeBase, name: str) -> None:
|
def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None:
|
||||||
super().__init__(type_)
|
super().__init__(type_)
|
||||||
self.name = name
|
self.variable = variable
|
||||||
|
|
||||||
class UnaryOp(Expression):
|
class UnaryOp(Expression):
|
||||||
"""
|
"""
|
||||||
@ -348,13 +279,23 @@ class StatementIf(Statement):
|
|||||||
self.statements = []
|
self.statements = []
|
||||||
self.else_statements = []
|
self.else_statements = []
|
||||||
|
|
||||||
FunctionParam = Tuple[str, TypeBase]
|
class FunctionParam:
|
||||||
|
__slots__ = ('name', 'type', 'type_var', )
|
||||||
|
|
||||||
|
name: str
|
||||||
|
type: TypeBase
|
||||||
|
type_var: Optional[TypeVar]
|
||||||
|
|
||||||
|
def __init__(self, name: str, type_: TypeBase) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.type = type_
|
||||||
|
self.type_var = None
|
||||||
|
|
||||||
class Function:
|
class Function:
|
||||||
"""
|
"""
|
||||||
A function processes input and produces output
|
A function processes input and produces output
|
||||||
"""
|
"""
|
||||||
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', )
|
__slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', )
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
lineno: int
|
lineno: int
|
||||||
@ -362,6 +303,7 @@ class Function:
|
|||||||
imported: bool
|
imported: bool
|
||||||
statements: List[Statement]
|
statements: List[Statement]
|
||||||
returns: TypeBase
|
returns: TypeBase
|
||||||
|
returns_type_var: Optional[TypeVar]
|
||||||
posonlyargs: List[FunctionParam]
|
posonlyargs: List[FunctionParam]
|
||||||
|
|
||||||
def __init__(self, name: str, lineno: int) -> None:
|
def __init__(self, name: str, lineno: int) -> None:
|
||||||
@ -371,6 +313,7 @@ class Function:
|
|||||||
self.imported = False
|
self.imported = False
|
||||||
self.statements = []
|
self.statements = []
|
||||||
self.returns = TypeNone()
|
self.returns = TypeNone()
|
||||||
|
self.returns_type_var = None
|
||||||
self.posonlyargs = []
|
self.posonlyargs = []
|
||||||
|
|
||||||
class StructConstructor(Function):
|
class StructConstructor(Function):
|
||||||
@ -390,7 +333,7 @@ class StructConstructor(Function):
|
|||||||
self.returns = struct
|
self.returns = struct
|
||||||
|
|
||||||
for mem in struct.members:
|
for mem in struct.members:
|
||||||
self.posonlyargs.append((mem.name, mem.type, ))
|
self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
|
||||||
|
|
||||||
self.struct = struct
|
self.struct = struct
|
||||||
|
|
||||||
@ -410,7 +353,7 @@ class TupleConstructor(Function):
|
|||||||
self.returns = tuple_
|
self.returns = tuple_
|
||||||
|
|
||||||
for mem in tuple_.members:
|
for mem in tuple_.members:
|
||||||
self.posonlyargs.append((f'arg{mem.idx}', mem.type, ))
|
self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
|
||||||
|
|
||||||
self.tuple = tuple_
|
self.tuple = tuple_
|
||||||
|
|
||||||
@ -439,10 +382,10 @@ class ModuleDataBlock:
|
|||||||
"""
|
"""
|
||||||
__slots__ = ('data', 'address', )
|
__slots__ = ('data', 'address', )
|
||||||
|
|
||||||
data: List[Constant]
|
data: List[ConstantPrimitive]
|
||||||
address: Optional[int]
|
address: Optional[int]
|
||||||
|
|
||||||
def __init__(self, data: List[Constant]) -> None:
|
def __init__(self, data: List[ConstantPrimitive]) -> None:
|
||||||
self.data = data
|
self.data = data
|
||||||
self.address = None
|
self.address = None
|
||||||
|
|
||||||
|
|||||||
143
phasm/parser.py
143
phasm/parser.py
@ -35,9 +35,7 @@ from .ourlang import (
|
|||||||
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
|
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
|
||||||
BinaryOp,
|
BinaryOp,
|
||||||
Constant,
|
Constant,
|
||||||
ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64,
|
ConstantPrimitive, ConstantTuple, ConstantStaticArray,
|
||||||
ConstantUInt8, ConstantUInt32, ConstantUInt64,
|
|
||||||
ConstantTuple, ConstantStaticArray,
|
|
||||||
|
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
StructConstructor, TupleConstructor,
|
StructConstructor, TupleConstructor,
|
||||||
@ -48,6 +46,7 @@ from .ourlang import (
|
|||||||
Statement,
|
Statement,
|
||||||
StatementIf, StatementPass, StatementReturn,
|
StatementIf, StatementPass, StatementReturn,
|
||||||
|
|
||||||
|
FunctionParam,
|
||||||
ModuleConstantDef,
|
ModuleConstantDef,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,7 +59,7 @@ def phasm_parse(source: str) -> Module:
|
|||||||
our_visitor = OurVisitor()
|
our_visitor = OurVisitor()
|
||||||
return our_visitor.visit_Module(res)
|
return our_visitor.visit_Module(res)
|
||||||
|
|
||||||
OurLocals = Dict[str, TypeBase]
|
OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants?
|
||||||
|
|
||||||
class OurVisitor:
|
class OurVisitor:
|
||||||
"""
|
"""
|
||||||
@ -141,7 +140,7 @@ class OurVisitor:
|
|||||||
if not arg.annotation:
|
if not arg.annotation:
|
||||||
_raise_static_error(node, 'Type is required')
|
_raise_static_error(node, 'Type is required')
|
||||||
|
|
||||||
function.posonlyargs.append((
|
function.posonlyargs.append(FunctionParam(
|
||||||
arg.arg,
|
arg.arg,
|
||||||
self.visit_type(module, arg.annotation),
|
self.visit_type(module, arg.annotation),
|
||||||
))
|
))
|
||||||
@ -210,18 +209,14 @@ class OurVisitor:
|
|||||||
|
|
||||||
exp_type = self.visit_type(module, node.annotation)
|
exp_type = self.visit_type(module, node.annotation)
|
||||||
|
|
||||||
if isinstance(exp_type, TypeInt32):
|
if isinstance(node.value, ast.Constant):
|
||||||
if not isinstance(node.value, ast.Constant):
|
return ModuleConstantDef(
|
||||||
_raise_static_error(node, 'Must be constant')
|
|
||||||
|
|
||||||
constant = ModuleConstantDef(
|
|
||||||
node.target.id,
|
node.target.id,
|
||||||
node.lineno,
|
node.lineno,
|
||||||
exp_type,
|
exp_type,
|
||||||
self.visit_Module_Constant(module, exp_type, node.value),
|
self.visit_Module_Constant(module, node.value),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
return constant
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeTuple):
|
if isinstance(exp_type, TypeTuple):
|
||||||
if not isinstance(node.value, ast.Tuple):
|
if not isinstance(node.value, ast.Tuple):
|
||||||
@ -231,7 +226,7 @@ class OurVisitor:
|
|||||||
_raise_static_error(node, 'Invalid number of tuple values')
|
_raise_static_error(node, 'Invalid number of tuple values')
|
||||||
|
|
||||||
tuple_data = [
|
tuple_data = [
|
||||||
self.visit_Module_Constant(module, mem.type, arg_node)
|
self.visit_Module_Constant(module, arg_node)
|
||||||
for arg_node, mem in zip(node.value.elts, exp_type.members)
|
for arg_node, mem in zip(node.value.elts, exp_type.members)
|
||||||
if isinstance(arg_node, ast.Constant)
|
if isinstance(arg_node, ast.Constant)
|
||||||
]
|
]
|
||||||
@ -259,7 +254,7 @@ class OurVisitor:
|
|||||||
_raise_static_error(node, 'Invalid number of static array values')
|
_raise_static_error(node, 'Invalid number of static array values')
|
||||||
|
|
||||||
static_array_data = [
|
static_array_data = [
|
||||||
self.visit_Module_Constant(module, exp_type.member_type, arg_node)
|
self.visit_Module_Constant(module, arg_node)
|
||||||
for arg_node in node.value.elts
|
for arg_node in node.value.elts
|
||||||
if isinstance(arg_node, ast.Constant)
|
if isinstance(arg_node, ast.Constant)
|
||||||
]
|
]
|
||||||
@ -297,7 +292,10 @@ class OurVisitor:
|
|||||||
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
|
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
|
||||||
function = module.functions[node.name]
|
function = module.functions[node.name]
|
||||||
|
|
||||||
our_locals = dict(function.posonlyargs)
|
our_locals: OurLocals = {
|
||||||
|
x.name: x
|
||||||
|
for x in function.posonlyargs
|
||||||
|
}
|
||||||
|
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
function.statements.append(
|
function.statements.append(
|
||||||
@ -409,7 +407,7 @@ class OurVisitor:
|
|||||||
|
|
||||||
if isinstance(node, ast.Constant):
|
if isinstance(node, ast.Constant):
|
||||||
return self.visit_Module_Constant(
|
return self.visit_Module_Constant(
|
||||||
module, exp_type, node,
|
module, node,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(node, ast.Attribute):
|
if isinstance(node, ast.Attribute):
|
||||||
@ -427,11 +425,11 @@ class OurVisitor:
|
|||||||
_raise_static_error(node, 'Must be load context')
|
_raise_static_error(node, 'Must be load context')
|
||||||
|
|
||||||
if node.id in our_locals:
|
if node.id in our_locals:
|
||||||
act_type = our_locals[node.id]
|
param = our_locals[node.id]
|
||||||
if exp_type != act_type:
|
if exp_type != param.type:
|
||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}')
|
||||||
|
|
||||||
return VariableReference(act_type, node.id)
|
return VariableReference(param.type, param)
|
||||||
|
|
||||||
if node.id in module.constant_defs:
|
if node.id in module.constant_defs:
|
||||||
cdef = module.constant_defs[node.id]
|
cdef = module.constant_defs[node.id]
|
||||||
@ -541,10 +539,10 @@ class OurVisitor:
|
|||||||
if exp_type.__class__ != func.returns.__class__:
|
if exp_type.__class__ != func.returns.__class__:
|
||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
|
||||||
|
|
||||||
if func.returns.__class__ != func.posonlyargs[0][1].__class__:
|
if func.returns.__class__ != func.posonlyargs[0].type.__class__:
|
||||||
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}')
|
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')
|
||||||
|
|
||||||
if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__:
|
if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
|
||||||
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
|
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
|
||||||
|
|
||||||
return Fold(
|
return Fold(
|
||||||
@ -560,16 +558,16 @@ class OurVisitor:
|
|||||||
|
|
||||||
func = module.functions[node.func.id]
|
func = module.functions[node.func.id]
|
||||||
|
|
||||||
if func.returns != exp_type:
|
# if func.returns != exp_type:
|
||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
|
# _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
|
||||||
|
|
||||||
if len(func.posonlyargs) != len(node.args):
|
if len(func.posonlyargs) != len(node.args):
|
||||||
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
|
_raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')
|
||||||
|
|
||||||
result = FunctionCall(func)
|
result = FunctionCall(func)
|
||||||
result.arguments.extend(
|
result.arguments.extend(
|
||||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr)
|
self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr)
|
||||||
for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs)
|
for arg_expr, param in zip(node.args, func.posonlyargs)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -586,7 +584,9 @@ class OurVisitor:
|
|||||||
if not node.value.id in our_locals:
|
if not node.value.id in our_locals:
|
||||||
_raise_static_error(node, f'Undefined variable {node.value.id}')
|
_raise_static_error(node, f'Undefined variable {node.value.id}')
|
||||||
|
|
||||||
node_typ = our_locals[node.value.id]
|
param = our_locals[node.value.id]
|
||||||
|
|
||||||
|
node_typ = param.type
|
||||||
if not isinstance(node_typ, TypeStruct):
|
if not isinstance(node_typ, TypeStruct):
|
||||||
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
|
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
|
||||||
|
|
||||||
@ -598,7 +598,7 @@ class OurVisitor:
|
|||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
|
||||||
|
|
||||||
return AccessStructMember(
|
return AccessStructMember(
|
||||||
VariableReference(node_typ, node.value.id),
|
VariableReference(node_typ, param),
|
||||||
member,
|
member,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -614,8 +614,9 @@ class OurVisitor:
|
|||||||
|
|
||||||
varref: Union[ModuleConstantReference, VariableReference]
|
varref: Union[ModuleConstantReference, VariableReference]
|
||||||
if node.value.id in our_locals:
|
if node.value.id in our_locals:
|
||||||
node_typ = our_locals[node.value.id]
|
param = our_locals[node.value.id]
|
||||||
varref = VariableReference(node_typ, node.value.id)
|
node_typ = param.type
|
||||||
|
varref = VariableReference(param.type, param)
|
||||||
elif node.value.id in module.constant_defs:
|
elif node.value.id in module.constant_defs:
|
||||||
constant_def = module.constant_defs[node.value.id]
|
constant_def = module.constant_defs[node.value.id]
|
||||||
node_typ = constant_def.type
|
node_typ = constant_def.type
|
||||||
@ -642,12 +643,15 @@ class OurVisitor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(node_typ, TypeTuple):
|
if isinstance(node_typ, TypeTuple):
|
||||||
if not isinstance(slice_expr, ConstantUInt32):
|
if not isinstance(slice_expr, ConstantPrimitive):
|
||||||
_raise_static_error(node, 'Must subscript using a constant index')
|
_raise_static_error(node, 'Must subscript using a constant index')
|
||||||
|
|
||||||
idx = slice_expr.value
|
idx = slice_expr.value
|
||||||
|
|
||||||
if len(node_typ.members) <= idx:
|
if not isinstance(idx, int):
|
||||||
|
_raise_static_error(node, 'Must subscript using a constant integer index')
|
||||||
|
|
||||||
|
if not (0 <= idx < len(node_typ.members)):
|
||||||
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
|
_raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')
|
||||||
|
|
||||||
tuple_member = node_typ.members[idx]
|
tuple_member = node_typ.members[idx]
|
||||||
@ -666,7 +670,7 @@ class OurVisitor:
|
|||||||
if exp_type != node_typ.member_type:
|
if exp_type != node_typ.member_type:
|
||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}')
|
||||||
|
|
||||||
if not isinstance(slice_expr, ConstantInt32):
|
if not isinstance(slice_expr, ConstantPrimitive):
|
||||||
return AccessStaticArrayMember(
|
return AccessStaticArrayMember(
|
||||||
varref,
|
varref,
|
||||||
node_typ,
|
node_typ,
|
||||||
@ -675,7 +679,10 @@ class OurVisitor:
|
|||||||
|
|
||||||
idx = slice_expr.value
|
idx = slice_expr.value
|
||||||
|
|
||||||
if len(node_typ.members) <= idx:
|
if not isinstance(idx, int):
|
||||||
|
_raise_static_error(node, 'Must subscript using an integer index')
|
||||||
|
|
||||||
|
if not (0 <= idx < len(node_typ.members)):
|
||||||
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
|
_raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')
|
||||||
|
|
||||||
static_array_member = node_typ.members[idx]
|
static_array_member = node_typ.members[idx]
|
||||||
@ -688,73 +695,15 @@ class OurVisitor:
|
|||||||
|
|
||||||
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
|
_raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
|
||||||
|
|
||||||
def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant:
|
def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive:
|
||||||
del module
|
del module
|
||||||
|
|
||||||
_not_implemented(node.kind is None, 'Constant.kind')
|
_not_implemented(node.kind is None, 'Constant.kind')
|
||||||
|
|
||||||
if isinstance(exp_type, TypeUInt8):
|
if isinstance(node.value, (int, float, )):
|
||||||
if not isinstance(node.value, int):
|
return ConstantPrimitive(node.value)
|
||||||
_raise_static_error(node, 'Expected integer value')
|
|
||||||
|
|
||||||
if node.value < 0 or node.value > 255:
|
raise NotImplementedError(f'{node.value} as constant')
|
||||||
_raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}')
|
|
||||||
|
|
||||||
return ConstantUInt8(exp_type, node.value)
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeUInt32):
|
|
||||||
if not isinstance(node.value, int):
|
|
||||||
_raise_static_error(node, 'Expected integer value')
|
|
||||||
|
|
||||||
if node.value < 0 or node.value > 4294967295:
|
|
||||||
_raise_static_error(node, 'Integer value out of range')
|
|
||||||
|
|
||||||
return ConstantUInt32(exp_type, node.value)
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeUInt64):
|
|
||||||
if not isinstance(node.value, int):
|
|
||||||
_raise_static_error(node, 'Expected integer value')
|
|
||||||
|
|
||||||
if node.value < 0 or node.value > 18446744073709551615:
|
|
||||||
_raise_static_error(node, 'Integer value out of range')
|
|
||||||
|
|
||||||
return ConstantUInt64(exp_type, node.value)
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeInt32):
|
|
||||||
if not isinstance(node.value, int):
|
|
||||||
_raise_static_error(node, 'Expected integer value')
|
|
||||||
|
|
||||||
if node.value < -2147483648 or node.value > 2147483647:
|
|
||||||
_raise_static_error(node, 'Integer value out of range')
|
|
||||||
|
|
||||||
return ConstantInt32(exp_type, node.value)
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeInt64):
|
|
||||||
if not isinstance(node.value, int):
|
|
||||||
_raise_static_error(node, 'Expected integer value')
|
|
||||||
|
|
||||||
if node.value < -9223372036854775808 or node.value > 9223372036854775807:
|
|
||||||
_raise_static_error(node, 'Integer value out of range')
|
|
||||||
|
|
||||||
return ConstantInt64(exp_type, node.value)
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeFloat32):
|
|
||||||
if not isinstance(node.value, (float, int, )):
|
|
||||||
_raise_static_error(node, 'Expected float value')
|
|
||||||
|
|
||||||
# FIXME: Range check
|
|
||||||
|
|
||||||
return ConstantFloat32(exp_type, node.value)
|
|
||||||
|
|
||||||
if isinstance(exp_type, TypeFloat64):
|
|
||||||
if not isinstance(node.value, (float, int, )):
|
|
||||||
_raise_static_error(node, 'Expected float value')
|
|
||||||
|
|
||||||
# FIXME: Range check
|
|
||||||
|
|
||||||
return ConstantFloat64(exp_type, node.value)
|
|
||||||
|
|
||||||
raise NotImplementedError(f'{node} as const for type {exp_type}')
|
|
||||||
|
|
||||||
def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
|
def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
|
||||||
if isinstance(node, ast.Constant):
|
if isinstance(node, ast.Constant):
|
||||||
|
|||||||
115
phasm/typer.py
Normal file
115
phasm/typer.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
"""
|
||||||
|
Type checks and enriches the given ast
|
||||||
|
"""
|
||||||
|
from . import ourlang
|
||||||
|
|
||||||
|
from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar
|
||||||
|
|
||||||
|
def phasm_type(inp: ourlang.Module) -> None:
|
||||||
|
module(inp)
|
||||||
|
|
||||||
|
def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar':
|
||||||
|
if isinstance(inp, ourlang.ConstantPrimitive):
|
||||||
|
result = ctx.new_var()
|
||||||
|
|
||||||
|
if not isinstance(inp.value, int):
|
||||||
|
raise NotImplementedError('Float constants in new type system')
|
||||||
|
|
||||||
|
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||||
|
|
||||||
|
# Need at least this many bits to store this constant value
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
|
||||||
|
# Don't dictate anything about signedness - you can use a signed
|
||||||
|
# constant in an unsigned variable if the bits fit
|
||||||
|
result.add_constraint(TypeConstraintSigned(None))
|
||||||
|
|
||||||
|
result.add_location(str(inp.value))
|
||||||
|
|
||||||
|
inp.type_var = result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
raise NotImplementedError(constant, inp)
|
||||||
|
|
||||||
|
def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
||||||
|
if isinstance(inp, ourlang.Constant):
|
||||||
|
return constant(ctx, inp)
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
|
assert inp.variable.type_var is not None, inp
|
||||||
|
return inp.variable.type_var
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.BinaryOp):
|
||||||
|
if inp.operator not in ('+', '-', '|', '&', '^'):
|
||||||
|
raise NotImplementedError(expression, inp, inp.operator)
|
||||||
|
|
||||||
|
left = expression(ctx, inp.left)
|
||||||
|
right = expression(ctx, inp.right)
|
||||||
|
ctx.unify(left, right)
|
||||||
|
return left
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.FunctionCall):
|
||||||
|
assert inp.function.returns_type_var is not None
|
||||||
|
if inp.function.posonlyargs:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return inp.function.returns_type_var
|
||||||
|
|
||||||
|
raise NotImplementedError(expression, inp)
|
||||||
|
|
||||||
|
def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
||||||
|
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
|
||||||
|
raise NotImplementedError('Functions with not just a return statement')
|
||||||
|
typ = expression(ctx, inp.statements[0].value)
|
||||||
|
|
||||||
|
assert inp.returns_type_var is not None
|
||||||
|
ctx.unify(inp.returns_type_var, typ)
|
||||||
|
return
|
||||||
|
|
||||||
|
def module(inp: ourlang.Module) -> None:
|
||||||
|
ctx = Context()
|
||||||
|
|
||||||
|
for func in inp.functions.values():
|
||||||
|
func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)')
|
||||||
|
for param in func.posonlyargs:
|
||||||
|
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
|
||||||
|
|
||||||
|
for func in inp.functions.values():
|
||||||
|
function(ctx, func)
|
||||||
|
|
||||||
|
from . import typing
|
||||||
|
|
||||||
|
def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar:
|
||||||
|
result = ctx.new_var()
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeUInt8):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
|
||||||
|
result.add_constraint(TypeConstraintSigned(False))
|
||||||
|
result.add_location(location)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeUInt32):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||||
|
result.add_constraint(TypeConstraintSigned(False))
|
||||||
|
result.add_location(location)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeUInt64):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
|
||||||
|
result.add_constraint(TypeConstraintSigned(False))
|
||||||
|
result.add_location(location)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeInt32):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||||
|
result.add_constraint(TypeConstraintSigned(True))
|
||||||
|
result.add_location(location)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeInt64):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
|
||||||
|
result.add_constraint(TypeConstraintSigned(True))
|
||||||
|
result.add_location(location)
|
||||||
|
return result
|
||||||
|
|
||||||
|
raise NotImplementedError(_convert_old_type, inp)
|
||||||
179
phasm/typing.py
179
phasm/typing.py
@ -1,7 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
The phasm type system
|
The phasm type system
|
||||||
"""
|
"""
|
||||||
from typing import Optional, List
|
from typing import Dict, Optional, List, Type
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from .exceptions import TypingError
|
||||||
|
|
||||||
class TypeBase:
|
class TypeBase:
|
||||||
"""
|
"""
|
||||||
@ -200,3 +204,176 @@ class TypeStruct(TypeBase):
|
|||||||
x.type.alloc_size()
|
x.type.alloc_size()
|
||||||
for x in self.members
|
for x in self.members
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## NEW STUFF BELOW
|
||||||
|
|
||||||
|
class TypingNarrowProtoError(TypingError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TypingNarrowError(TypingError):
|
||||||
|
def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f'Cannot narrow types {l} and {r}: {msg}'
|
||||||
|
)
|
||||||
|
|
||||||
|
class TypeConstraintBase:
|
||||||
|
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase':
|
||||||
|
raise NotImplementedError('narrow', self, other)
|
||||||
|
|
||||||
|
class TypeConstraintPrimitive(TypeConstraintBase):
|
||||||
|
__slots__ = ('primitive', )
|
||||||
|
|
||||||
|
class Primitive(enum.Enum):
|
||||||
|
INT = 0
|
||||||
|
FLOAT = 1
|
||||||
|
|
||||||
|
primitive: Primitive
|
||||||
|
|
||||||
|
def __init__(self, primitive: Primitive) -> None:
|
||||||
|
self.primitive = primitive
|
||||||
|
|
||||||
|
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
|
||||||
|
if not isinstance(other, TypeConstraintPrimitive):
|
||||||
|
raise Exception('Invalid comparison')
|
||||||
|
|
||||||
|
if self.primitive != other.primitive:
|
||||||
|
raise TypingNarrowProtoError('Primitive does not match')
|
||||||
|
|
||||||
|
return TypeConstraintPrimitive(self.primitive)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'Primitive={self.primitive.name}'
|
||||||
|
|
||||||
|
class TypeConstraintSigned(TypeConstraintBase):
|
||||||
|
__slots__ = ('signed', )
|
||||||
|
|
||||||
|
signed: Optional[bool]
|
||||||
|
|
||||||
|
def __init__(self, signed: Optional[bool]) -> None:
|
||||||
|
self.signed = signed
|
||||||
|
|
||||||
|
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
|
||||||
|
if not isinstance(other, TypeConstraintSigned):
|
||||||
|
raise Exception('Invalid comparison')
|
||||||
|
|
||||||
|
if other.signed is None:
|
||||||
|
return TypeConstraintSigned(self.signed)
|
||||||
|
if self.signed is None:
|
||||||
|
return TypeConstraintSigned(other.signed)
|
||||||
|
|
||||||
|
if self.signed is not other.signed:
|
||||||
|
raise TypingNarrowProtoError('Signed does not match')
|
||||||
|
|
||||||
|
return TypeConstraintSigned(self.signed)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'Signed={self.signed}'
|
||||||
|
|
||||||
|
class TypeConstraintBitWidth(TypeConstraintBase):
|
||||||
|
__slots__ = ('minb', 'maxb', )
|
||||||
|
|
||||||
|
minb: int
|
||||||
|
maxb: int
|
||||||
|
|
||||||
|
def __init__(self, *, minb: int = 1, maxb: int = 64) -> None:
|
||||||
|
assert minb is not None or maxb is not None
|
||||||
|
assert maxb <= 64 # For now, support up to 64 bits values
|
||||||
|
|
||||||
|
self.minb = minb
|
||||||
|
self.maxb = maxb
|
||||||
|
|
||||||
|
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
|
||||||
|
if not isinstance(other, TypeConstraintBitWidth):
|
||||||
|
raise Exception('Invalid comparison')
|
||||||
|
|
||||||
|
if self.minb > other.maxb:
|
||||||
|
raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth')
|
||||||
|
|
||||||
|
if other.minb > self.maxb:
|
||||||
|
raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth')
|
||||||
|
|
||||||
|
return TypeConstraintBitWidth(
|
||||||
|
minb=max(self.minb, other.minb),
|
||||||
|
maxb=min(self.maxb, other.maxb),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'BitWidth={self.minb}..{self.maxb}'
|
||||||
|
|
||||||
|
class TypeVar:
|
||||||
|
def __init__(self, ctx: 'Context') -> None:
|
||||||
|
self.context = ctx
|
||||||
|
self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {}
|
||||||
|
self.locations: List[str] = []
|
||||||
|
|
||||||
|
def add_constraint(self, newconst: TypeConstraintBase) -> None:
|
||||||
|
if newconst.__class__ in self.constraints:
|
||||||
|
self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst)
|
||||||
|
else:
|
||||||
|
self.constraints[newconst.__class__] = newconst
|
||||||
|
|
||||||
|
def add_location(self, ref: str) -> None:
|
||||||
|
self.locations.append(ref)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
'TypeVar<'
|
||||||
|
+ '; '.join(map(repr, self.constraints.values()))
|
||||||
|
+ '; locations: '
|
||||||
|
+ ', '.join(self.locations)
|
||||||
|
+ '>'
|
||||||
|
)
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
def new_var(self) -> TypeVar:
|
||||||
|
return TypeVar(self)
|
||||||
|
|
||||||
|
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
|
||||||
|
newtypevar = self.new_var()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for const in l.constraints.values():
|
||||||
|
newtypevar.add_constraint(const)
|
||||||
|
for const in r.constraints.values():
|
||||||
|
newtypevar.add_constraint(const)
|
||||||
|
except TypingNarrowProtoError as ex:
|
||||||
|
raise TypingNarrowError(l, r, str(ex)) from None
|
||||||
|
|
||||||
|
newtypevar.locations.extend(l.locations)
|
||||||
|
newtypevar.locations.extend(r.locations)
|
||||||
|
|
||||||
|
# Make pointer locations to the constraints and locations
|
||||||
|
# so they get linked together throughout the unification
|
||||||
|
|
||||||
|
l.constraints = newtypevar.constraints
|
||||||
|
l.locations = newtypevar.locations
|
||||||
|
|
||||||
|
r.constraints = newtypevar.constraints
|
||||||
|
r.locations = newtypevar.locations
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def simplify(inp: TypeVar) -> Optional[str]:
|
||||||
|
tc_prim = inp.constraints.get(TypeConstraintPrimitive)
|
||||||
|
tc_bits = inp.constraints.get(TypeConstraintBitWidth)
|
||||||
|
tc_sign = inp.constraints.get(TypeConstraintSigned)
|
||||||
|
|
||||||
|
if tc_prim is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint
|
||||||
|
primitive = tc_prim.primitive
|
||||||
|
if primitive is TypeConstraintPrimitive.Primitive.INT:
|
||||||
|
if tc_bits is None or tc_sign is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
|
||||||
|
assert isinstance(tc_sign, TypeConstraintSigned) # type hint
|
||||||
|
|
||||||
|
if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64):
|
||||||
|
return None
|
||||||
|
|
||||||
|
base = 'i' if tc_sign.signed else 'u'
|
||||||
|
return f'{base}{tc_bits.minb}'
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO
|
|||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import io
|
import io
|
||||||
|
import warnings
|
||||||
|
|
||||||
import pywasm.binary
|
import pywasm.binary
|
||||||
import wasm3
|
import wasm3
|
||||||
@ -13,6 +14,7 @@ import wasmtime
|
|||||||
|
|
||||||
from phasm.compiler import phasm_compile
|
from phasm.compiler import phasm_compile
|
||||||
from phasm.parser import phasm_parse
|
from phasm.parser import phasm_parse
|
||||||
|
from phasm.typer import phasm_type
|
||||||
from phasm import ourlang
|
from phasm import ourlang
|
||||||
from phasm import wasm
|
from phasm import wasm
|
||||||
|
|
||||||
@ -40,6 +42,10 @@ class RunnerBase:
|
|||||||
Parses the Phasm code into an AST
|
Parses the Phasm code into an AST
|
||||||
"""
|
"""
|
||||||
self.phasm_ast = phasm_parse(self.phasm_code)
|
self.phasm_ast = phasm_parse(self.phasm_code)
|
||||||
|
try:
|
||||||
|
phasm_type(self.phasm_ast)
|
||||||
|
except NotImplementedError as exc:
|
||||||
|
warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}')
|
||||||
|
|
||||||
def compile_ast(self) -> None:
|
def compile_ast(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -304,6 +304,21 @@ def testEntry(a: i32, b: i32) -> i32:
|
|||||||
assert 1 == suite.run_code(10, 20).returned_value
|
assert 1 == suite.run_code(10, 20).returned_value
|
||||||
assert 0 == suite.run_code(10, 10).returned_value
|
assert 0 == suite.run_code(10, 10).returned_value
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
def test_call_no_args():
|
||||||
|
code_py = """
|
||||||
|
def helper() -> i32:
|
||||||
|
return 19
|
||||||
|
|
||||||
|
@exported
|
||||||
|
def testEntry() -> i32:
|
||||||
|
return helper()
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = Suite(code_py).run_code()
|
||||||
|
|
||||||
|
assert 19 == result.returned_value
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
@pytest.mark.integration_test
|
||||||
def test_call_pre_defined():
|
def test_call_pre_defined():
|
||||||
code_py = """
|
code_py = """
|
||||||
|
|||||||
31
tests/integration/test_type_checks.py
Normal file
31
tests/integration/test_type_checks.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from phasm.parser import phasm_parse
|
||||||
|
from phasm.typer import phasm_type
|
||||||
|
from phasm.exceptions import TypingError
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
def test_constant_too_wide():
|
||||||
|
code_py = """
|
||||||
|
def func_const() -> u8:
|
||||||
|
return 0xFFF
|
||||||
|
"""
|
||||||
|
|
||||||
|
ast = phasm_parse(code_py)
|
||||||
|
with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'):
|
||||||
|
phasm_type(ast)
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
@pytest.mark.parametrize('type_', [32, 64])
|
||||||
|
def test_signed_mismatch(type_):
|
||||||
|
code_py = f"""
|
||||||
|
def func_const() -> u{type_}:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def func_call() -> i{type_}:
|
||||||
|
return func_const()
|
||||||
|
"""
|
||||||
|
|
||||||
|
ast = phasm_parse(code_py)
|
||||||
|
with pytest.raises(TypingError, match='Signed does not match'):
|
||||||
|
phasm_type(ast)
|
||||||
Loading…
x
Reference in New Issue
Block a user