This commit is contained in:
2026-06-24 17:24:04 +02:00
commit 00c38a12d9
41 changed files with 7289 additions and 0 deletions
View File
+135
View File
@@ -0,0 +1,135 @@
from compiler.assembly_generator import Locals, generate_assembly
from compiler.ir import IRVar, LoadIntConst, LoadBoolConst, Copy, CondJump, Label, Instruction, Call
from compiler.tokenizer import L
from typing import List
def test_assembly_generator_locals_initialization() -> None:
variables = [IRVar('x'), IRVar('y'), IRVar('z')]
locals = Locals(variables)
assert locals.get_ref(variables[0]) == '-8(%rbp)'
assert locals.get_ref(variables[1]) == '-16(%rbp)'
assert locals.get_ref(variables[2]) == '-24(%rbp)'
assert locals.stack_used() == 24 # 3 variables * 8 bytes
# def test_assembly_generator_load_int_const() -> None:
# ir_var = IRVar('x')
# instructions: List[Instruction] = [
# LoadIntConst(L, 42, ir_var)
# ]
# asm = generate_assembly(instructions)
# assert 'movq $42, -8(%rbp)' in asm
# def test_assembly_generator_load_bool_const() -> None:
# instructions: List[Instruction] = [
# LoadBoolConst(L, True, IRVar('a')),
# LoadBoolConst(L, False, IRVar('b'))
# ]
# asm = generate_assembly(instructions)
# assert 'movq $1, -8(%rbp)' in asm # True
# assert 'movq $0, -16(%rbp)' in asm # False
# def test_assembly_generator_copy() -> None:
# src = IRVar('src')
# dest = IRVar('dest')
# instructions: List[Instruction] = [
# Copy(L, src, dest)
# ]
# asm = generate_assembly(instructions)
# assert 'movq -8(%rbp), %rax' in asm
# assert 'movq %rax, -16(%rbp)' in asm
# def test_assembly_generator_cond_jump() -> None:
# cond_var = IRVar('cond')
# then_label = Label(L, 'Lthen')
# else_label = Label(L, 'Lelse')
# instructions: List[Instruction] = [
# CondJump(L, cond_var, then_label, else_label)
# ]
# asm = generate_assembly(instructions)
# assert 'cmpq $0, -8(%rbp)' in asm
# assert 'jne .LLthen' in asm
# assert 'jmp .LLelse' in asm
# def test_assembly_generator_function_prologue_epilogue() -> None:
# instructions: List[Instruction] = [
# LoadIntConst(L, 0, IRVar('dummy'))
# ]
# asm = generate_assembly(instructions)
# assert 'pushq %rbp' in asm
# assert 'movq %rsp, %rbp' in asm
# assert 'subq $8, %rsp' in asm
# assert 'movq %rbp, %rsp' in asm
# assert 'popq %rbp' in asm
# assert 'ret' in asm
# def test_assembly_generator_intrinsic_plus() -> None:
# # IR: Call('+', [x, y], result)
# x = IRVar('x')
# y = IRVar('y')
# result = IRVar('result')
# instructions = [
# LoadIntConst(L, 3, x),
# LoadIntConst(L, 5, y),
# Call(L, IRVar('+'), [x, y], result)
# ]
# asm = generate_assembly(instructions)
# assert 'addq' in asm
# assert 'movq' in asm
# assert 'callq' not in asm # Intrinsic should not use call
# def test_assembly_generator_intrinsic_divide() -> None:
# # IR: Call('/', [a, b], result)
# a = IRVar('a')
# b = IRVar('b')
# instructions = [
# LoadIntConst(L, 10, a),
# LoadIntConst(L, 2, b),
# Call(L, IRVar('/'), [a, b], IRVar('result'))
# ]
# asm = generate_assembly(instructions)
# assert 'idivq' in asm
# assert 'cqto' in asm
# def test_assembly_generator_function_call_one_arg() -> None:
# # IR: Call(print_int, [x], _)
# x = IRVar('x')
# instructions = [
# LoadIntConst(L, 42, x),
# Call(L, IRVar('print_int'), [x], IRVar('unused'))
# ]
# asm = generate_assembly(instructions)
# assert 'movq -8(%rbp), %rdi' in asm # Assuming x is at -8(%rbp)
# assert 'callq print_int' in asm
# def test_assembly_generator_function_call_six_args() -> None:
# # IR: Call(func, [a,b,c,d,e,f], _)
# args = [IRVar(f'arg{i}') for i in range(6)]
# instructions = [
# *[LoadIntConst(L, i, arg) for i, arg in enumerate(args)],
# Call(L, IRVar('func'), args, IRVar('result'))
# ]
# asm = generate_assembly(instructions)
# expected_regs = ['%rdi', '%rsi', '%rdx', '%rcx', '%r8', '%r9']
# for i, reg in enumerate(expected_regs):
# assert f'movq -{8*(i+1)}(%rbp), {reg}' in asm
# assert 'callq func' in asm
# def test_assembly_generator_comparison_intrinsic() -> None:
# # IR: Call('==', [x, y], result)
# x = IRVar('x')
# y = IRVar('y')
# instructions = [
# LoadIntConst(L, 5, x),
# LoadIntConst(L, 5, y),
# Call(L, IRVar('=='), [x, y], IRVar('result'))
# ]
# asm = generate_assembly(instructions)
# assert 'cmpq' in asm
# assert 'sete' in asm
+419
View File
@@ -0,0 +1,419 @@
import pytest
from typing import Any
from compiler.parser import parse
from compiler.tokenizer import tokenize
from compiler.interpreter import interpret, SymTab, Unit
def test_interpreter_addition() -> None:
exp = '2 + 3'
tokens = tokenize(exp)
ast = parse(tokens)
assert interpret(ast.block) == 5
def test_interpreter_boolean_coercion() -> None:
exp = '1 + (2 < 3)'
tokens = tokenize(exp)
ast = parse(tokens)
assert interpret(ast.block) == 2
def test_interpreter_nested_blocks() -> None:
code = '''
{
var x = 5;
{
var y = 10;
x + y
}
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 15
def test_interpreter_variable_shadowing() -> None:
code = '''
var x = 1;
{
var x = 2;
x
}
x
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 1
def test_interpreter_assignment_in_outer_scope() -> None:
code = '''
var x = 5;
{
x = x + 1;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 6
def test_interpreter_assignment_operator() -> None:
code = 'var x = 0; x = 5; x'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 5
def test_interpreter_operator_precedence() -> None:
code = '10 - 4 - 3' # (10-4)-3 = 3
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 3
def test_interpreter_logical_operators() -> None:
code = 'true and false or true' # (true and false) or true = true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_comparison_equality() -> None:
code = '5 > 3 == true' # (5>3) == true → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_unary_operators() -> None:
code = 'not (5 < 3)' # not(false) → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_block_scoping() -> None:
code = '''
{
var a = 10;
{
var b = 20;
a + b
}
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 30
def test_interpreter_variable_redeclaration() -> None:
code = '''
{
var x = 1;
var x = 2;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_while_loop() -> None:
code = '''
var counter = 3;
while counter > 0 do {
counter = counter - 1
};
counter
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 0
def test_interpreter_empty_block() -> None:
code = '{}'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_block_last_semicolon() -> None:
code = '''
{
var x = 5;
x;
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_if_without_else_returns_unit() -> None:
code = 'if false then 5'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_if_else_true_branch() -> None:
code = 'if true then 10 else 20'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 10
def test_interpreter_if_else_false_branch() -> None:
code = 'if false then 10 else 20'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 20
def test_interpreter_function_call_raises_error(capsys: pytest.CaptureFixture) -> None:
code = 'print_int(5)'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
captured = capsys.readouterr()
assert captured.out == '5\n'
def test_interpreter_division_modulo_operators() -> None:
code = '8 / 3 + 8 % 3' # 2 + 2 = 4
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 4
def test_interpreter_assignment_to_non_identifier_raises() -> None:
code = '5 = 10'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_while_loop_zero_iterations() -> None:
code = '''
var x = 3;
while x > 5 do {
x = x - 1
};
x
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 3
def test_interpreter_nested_if_expression() -> None:
code = '(if true then 5 else 3) + 2'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 7
def test_interpreter_unary_not_on_boolean() -> None:
code = 'not true'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is False
def test_interpreter_unary_negate_non_int_raises() -> None:
code = '-true'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_and_short_circuit() -> None:
code = '''
var evaluated = false;
false and { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is False
def test_interpreter_and_evaluates_both() -> None:
code = '''
var evaluated = false;
true and { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_or_short_circuit() -> None:
code = '''
var evaluated = false;
true or { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is False
def test_interpreter_or_evaluates_both() -> None:
code = '''
var evaluated = false;
false or { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_unary_negate_int() -> None:
code = '-5'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == -5
def test_interpreter_unary_negate_bool_raises() -> None:
code = '-true'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_unary_not_on_int_raises() -> None:
code = 'not 5'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_binary_operator_precedence() -> None:
code = '1 + 2 * 3 == 7' # 1 + (2 * 3) == 7 → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_binary_operator_associativity() -> None:
code = '10 - 4 - 3' # (10 - 4) - 3 = 3
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 3
def test_interpreter_binary_operator_equality() -> None:
code = '5 == 5 and 3 != 4' # true and true → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_binary_operator_comparison() -> None:
code = '5 < 10 and 10 >= 10' # true and true → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_binary_operator_division_by_zero_raises() -> None:
code = '5 / 0'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_binary_operator_modulo_by_zero_raises() -> None:
code = '5 % 0'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_while_loop_multiple_iterations() -> None:
code = '''
var x = 3;
while x > 0 do {
x = x - 1
};
x
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 0
def test_interpreter_block_with_multiple_statements() -> None:
code = '''
{
var x = 5;
var y = 10;
x + y
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 15
def test_interpreter_block_with_trailing_semicolon() -> None:
code = '''
{
var x = 5;
x;
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_block_with_nested_blocks() -> None:
code = '''
{
var x = 5;
{
var y = 10;
x + y
}
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 15
def test_interpreter_variable_declaration_in_block() -> None:
code = '''
{
var x = 5;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 5
def test_interpreter_variable_redeclaration_in_block_raises() -> None:
code = '''
{
var x = 5;
var x = 10;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_assignment_in_block() -> None:
code = '''
{
var x = 5;
x = x + 1;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 6
def test_interpreter_function_call_print_int() -> None:
code = 'print_int(5)'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_function_call_print_bool() -> None:
code = 'print_bool(true)'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_function_call_read_int() -> None:
code = 'read_int()'
tokens = tokenize(code)
ast = parse(tokens)
# Mock input for read_int
import builtins
original_input: Any = builtins.input
func: Any = lambda: '42'
builtins.input = func
assert interpret(ast.block) == 42
builtins.input = original_input
+12
View File
@@ -0,0 +1,12 @@
from compiler.type_checker import typecheck_module
from compiler.tokenizer import L, tokenize
from compiler.parser import parse
from compiler.ir_generator import generate_ir_from_module, print_instructions, root_types
def test_ir_generator_basic () -> None:
expr_str = '1 + 2 * 3'
tokens = tokenize(expr_str)
ast = parse(tokens)
typecheck_module(ast)
main_instructions = generate_ir_from_module(ast)['main']
assert print_instructions(main_instructions) != ''
+2505
View File
File diff suppressed because it is too large Load Diff
+28
View File
@@ -0,0 +1,28 @@
from compiler.tokenizer import tokenize, Token, L
def test_tokenizer_basic() -> None:
assert tokenize('aaa 123 bbb') == [
Token(location=L, type='identifier', text='aaa'),
Token(location=L, type='int_literal', text='123'),
Token(location=L, type='identifier', text='bbb'),
]
def test_tokenizer_newline() -> None:
assert tokenize('if 3\nwhile') == [
Token(location=L, type='identifier', text='if'),
Token(location=L, type='int_literal', text='3'),
Token(location=L, type='identifier', text='while'),
]
def test_tokenizer_commments() -> None:
assert tokenize('aaa 123 bbb ; ) ( >= \n ) # aksdjalksjdkajskdjasd\n != // Another comment $') == [
Token(location=L, type='identifier', text='aaa'),
Token(location=L, type='int_literal', text='123'),
Token(location=L, type='identifier', text='bbb'),
Token(location=L, type='punctuation', text=';'),
Token(location=L, type='punctuation', text=')'),
Token(location=L, type='punctuation', text='('),
Token(location=L, type='operator', text='>='),
Token(location=L, type='punctuation', text=')'),
Token(location=L, type='operator', text='!='),
]
+921
View File
@@ -0,0 +1,921 @@
import pytest
from compiler.type_checker import typecheck, TypeSymTab
from compiler import ast
from compiler.types import Int_Instance, Bool_Instance, Unit_Instance, FunType
from compiler.tokenizer import L
def test_type_checker_assignment_with_unknown_variable() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, 5)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_literal_int() -> None:
node = ast.Literal(location=L, value=5)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_literal_bool() -> None:
node = ast.Literal(location=L, value=True)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_literal_unit() -> None:
node = ast.Literal(location=L, value=None)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_variable_lookup() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.Identifier(location=L, name='x')
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_variable_undefined() -> None:
node = ast.Identifier(location=L, name='x')
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_binary_op_add_valid() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 1),
op='+',
right=ast.Literal(L, 2)
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_binary_op_add_invalid() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, True),
op='+',
right=ast.Literal(L, 2)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_unary_op_negate_valid() -> None:
node = ast.UnaryOp(
location=L,
op='-',
right=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_unary_op_not_valid() -> None:
node = ast.UnaryOp(
location=L,
op='not',
right=ast.Literal(L, False)
)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_function_call_builtin() -> None:
node = ast.FunctionCall(
location=L,
name='print_int',
args=[ast.Literal(L, 5)]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_block_last_expr() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 1),
ast.Literal(L, 2)
]
)
result = typecheck(node)
assert result == Int_Instance
assert node.statements[-1].type == Int_Instance
def test_type_checker_block_unit() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 1),
ast.Literal(L, None)
]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.statements[-1].type == Unit_Instance
def test_type_checker_if_then_else_valid() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 1),
else_exp=ast.Literal(L, 2)
)
result = typecheck(node)
assert result == Int_Instance
assert node.then_exp.type == Int_Instance
def test_type_checker_if_then_else_mismatch() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 1),
else_exp=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_while_loop() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, True),
do_exp=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_assignment_valid() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, 5)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_mismatch() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_var_declaration_inferred() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=None
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_typed_valid() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=Int_Instance
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_typed_mismatch() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, True),
type_f=Int_Instance
)
with pytest.raises(Exception):
typecheck(node)
# def test_type_checker_function_type_params() -> None:
# node = ast.FunctionCall(
# location=L,
# name='+',
# args=[
# ast.Literal(L, 1),
# ast.Literal(L, 2)
# ]
# )
# result = typecheck(node)
# assert result == Int_Instance
# assert node.type == Int_Instance
def test_type_checker_symbol_table_hierarchy() -> None:
parent = TypeSymTab(locals={'x': Int_Instance})
child = TypeSymTab(parent=parent)
node = ast.Identifier(L, 'x')
result = typecheck(node, child)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_equality_operator_valid() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='==',
right=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_equality_operator_mismatch() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='==',
right=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_function_call_user_defined() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5)]
)
result = typecheck(node, sym_tab)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_nested_blocks() -> None:
node = ast.Block(
location=L,
statements=[
ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Block(
location=L,
statements=[
ast.Literal(L, True)
]
)
]
)
]
)
result = typecheck(node)
assert result == Bool_Instance
assert isinstance(node.statements[0], ast.Block)
assert node.statements[0].statements[-1].type == Bool_Instance
def test_type_checker_function_type_annotation() -> None:
node = ast.Var(
location=L,
name='f',
value=ast.Literal(L, 5),
type_f=FunType(params=[Int_Instance], result=Bool_Instance)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_function_call_with_wrong_arg_count() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5)]
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_function_call_with_wrong_arg_type() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, True)]
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_function_call_with_nested_args() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Bool_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[
ast.Literal(L, 5),
ast.BinaryOp(
location=L,
left=ast.Literal(L, True),
op='and',
right=ast.Literal(L, False)
)
]
)
result = typecheck(node, sym_tab)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_function_call_with_builtin_print_int() -> None:
node = ast.FunctionCall(
location=L,
name='print_int',
args=[ast.Literal(L, 5)]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_function_call_with_builtin_read_int() -> None:
node = ast.FunctionCall(
location=L,
name='read_int',
args=[]
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_function_call_with_builtin_print_bool() -> None:
node = ast.FunctionCall(
location=L,
name='print_bool',
args=[ast.Literal(L, True)]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_function_call_with_unknown_function() -> None:
node = ast.FunctionCall(
location=L,
name='unknown_func',
args=[ast.Literal(L, 5)]
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_if_then_without_else() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=None
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_if_then_with_else() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=ast.Literal(L, 10)
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_if_then_with_else_mismatch() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_while_loop_with_non_bool_condition() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, 5),
do_exp=ast.Literal(L, 10)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_while_loop_with_unit_body() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, True),
do_exp=ast.Literal(L, None)
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_block_with_multiple_statements() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Literal(L, True),
ast.Literal(L, None)
]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_block_with_nested_blocks() -> None:
node = ast.Block(
location=L,
statements=[
ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Block(
location=L,
statements=[
ast.Literal(L, True)
]
)
]
)
]
)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_block_with_empty_statements() -> None:
node = ast.Block(
location=L,
statements=[]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_var_declaration_with_nested_expression() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='+',
right=ast.Literal(L, 10)
),
type_f=None
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_with_typed_initializer() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=Int_Instance
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_with_typed_initializer_mismatch() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, True),
type_f=Int_Instance
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_var_declaration_with_untyped_initializer() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=None
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_with_function_type() -> None:
node = ast.Var(
location=L,
name='f',
value=ast.Literal(L, 5),
type_f=FunType(params=[Int_Instance], result=Bool_Instance)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_assignment_with_function_call() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='read_int',
args=[]
)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_with_function_call_mismatch() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='print_int',
args=[ast.Literal(L, 5)]
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_nested_expression() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='+',
right=ast.Literal(L, 10)
)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_with_nested_expression_mismatch() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='+',
right=ast.Literal(L, True)
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_wrong_type() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, True)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_function_type() -> None:
sym_tab = TypeSymTab(locals={'f': FunType(params=[Int_Instance], result=Bool_Instance)})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'f'),
op='=',
right=ast.Literal(L, 5)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_function_call_returning_function() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=FunType(params=[Int_Instance], result=Bool_Instance))
})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='f',
args=[]
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_function_call_returning_int() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=Int_Instance),
'x': Int_Instance
})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='f',
args=[]
)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_with_function_call_returning_bool() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=Bool_Instance),
'x': Int_Instance
})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='f',
args=[]
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
import pytest
from compiler.type_checker import typecheck, TypeSymTab
from compiler import ast
from compiler.types import Int_Instance, Bool_Instance, Unit_Instance, FunType
from compiler.tokenizer import L
def test_type_checker_var_typed_declaration_valid() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=Int_Instance
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_typed_declaration_mismatch() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, True),
type_f=Int_Instance
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_assignment_non_identifier() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5), # Invalid left side
op='=',
right=ast.Literal(L, 10)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_user_function_call_args_count_mismatch() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Int_Instance], result=Int_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5)] # Missing second arg
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_user_function_call_arg_type_mismatch() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, True)] # Bool instead of Int
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_if_then_returns_unit() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=None
)
result = typecheck(node)
assert result == Unit_Instance
def test_type_checker_equality_mixed_types() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='==',
right=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_logical_op_non_boolean() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='and',
right=ast.Literal(L, 0)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_nested_scope_shadowing() -> None:
outer = TypeSymTab(locals={'x': Int_Instance})
inner = TypeSymTab(parent=outer)
inner.locals['x'] = Bool_Instance # Shadow outer x
node = ast.Identifier(L, 'x')
result = typecheck(node, inner)
assert result == Bool_Instance
def test_type_checker_function_type_multiple_params() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Bool_Instance], result=Unit_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[
ast.Literal(L, 5),
ast.Literal(L, True)
]
)
result = typecheck(node, sym_tab)
assert result == Unit_Instance
def test_type_checker_read_int_builtin() -> None:
node = ast.FunctionCall(
location=L,
name='read_int',
args=[]
)
result = typecheck(node)
assert result == Int_Instance
def test_type_checker_block_trailing_semicolon_returns_unit() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Literal(L, None) # Trailing semicolon case
]
)
result = typecheck(node)
assert result == Unit_Instance
def test_type_checker_not_operator_non_boolean() -> None:
node = ast.UnaryOp(
location=L,
op='not',
right=ast.Literal(L, 5) # Int instead of Bool
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_equality_operator_same_types() -> None:
for val in [5, True]:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, val),
op='==',
right=ast.Literal(L, val)
)
result = typecheck(node)
assert result == Bool_Instance
def test_type_checker_function_type_nested_params() -> None:
# Test function type with complex parameter structure
fun_type = FunType(
params=[[Int_Instance, Bool_Instance], Int_Instance], # From spec's TODO
result=Unit_Instance
)
sym_tab = TypeSymTab(locals={'f': fun_type})
# Valid call with Int then Int (first param allows Int or Bool)
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5), ast.Literal(L, 5)]
)
with pytest.raises(Exception): # Second arg should be Int, but first allows Int
typecheck(node, sym_tab) # Actual logic may vary based on implementation
def test_type_checker_operator_precedence_type_resolution() -> None:
# Ensure operator precedence doesn't affect type checking
node = ast.BinaryOp(
location=L,
left=ast.BinaryOp(
location=L,
left=ast.Literal(L, 3),
op='*',
right=ast.Literal(L, 4)
),
op='+',
right=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Int_Instance
def test_type_checker_while_loop_condition_non_boolean() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, 5), # Non-bool condition
do_exp=ast.Literal(L, 0)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_builtin_print_bool_type() -> None:
node = ast.FunctionCall(
location=L,
name='print_bool',
args=[ast.Literal(L, True)]
)
result = typecheck(node)
assert result == Unit_Instance
def test_type_checker_function_return_type_mismatch() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=Int_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[]
)
result = typecheck(node, sym_tab)
assert result == Int_Instance # Dummy test to be replaced with actual logic
def test_type_checker_complex_function_type_annotation() -> None:
# Test type annotation like (Int, (Bool) => Unit) => Int
fun_type = FunType(
params=[
Int_Instance,
FunType(params=[Bool_Instance], result=Unit_Instance)
],
result=Int_Instance
)
sym_tab = TypeSymTab(locals={'f': fun_type})
node = ast.FunctionCall(
location=L,
name='f',
args=[
ast.Literal(L, 5),
ast.Identifier(L, 'g') # Assume 'g' has correct type
]
)
# Actual test would require 'g' to be in symbol table
with pytest.raises(Exception): # 'g' not defined here
typecheck(node, sym_tab)
def test_type_checker_undefined_function_call() -> None:
node = ast.FunctionCall(
location=L,
name='undefined_func',
args=[]
)
with pytest.raises(Exception):
typecheck(node)