From 00c38a12d92fbb9c82ae543aa755a8035737d3dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=5BQuy=20Anh=5D=20=C2=ABElliot=C2=BB=20Nguyen?= Date: Wed, 24 Jun 2026 17:24:04 +0200 Subject: [PATCH] --- --- .dockerignore | 8 + .gitignore | 7 + .python-version | 1 + Dockerfile | 18 + README.md | 53 + check.sh | 6 + compiler.sh | 3 + exercises/1.py | 8 + exercises/6.txt | 22 + exercises/7.txt | 3 + exercises/a + b.drawio | 25 + exercises/asmprogram | Bin 0 -> 16712 bytes exercises/asmprogram.s | 55 + exercises/f(a * f(b)) + c.drawio | 49 + exercises/f(a + b, b + c).drawio | 49 + exercises/f(f(a)).drawio | 25 + exercises/while i < 100 do i = i + 1.drawio | 61 + mypy.ini | 3 + poetry.lock | 173 ++ pyproject.toml | 31 + src/compiler/__init__.py | 0 src/compiler/__main__.py | 117 + src/compiler/assembler.py | 373 +++ src/compiler/assembly_generator.py | 374 +++ src/compiler/ast.py | 85 + src/compiler/interpreter.py | 311 +++ src/compiler/intrinsics.py | 113 + src/compiler/ir.py | 86 + src/compiler/ir_generator.py | 424 ++++ src/compiler/parser.py | 337 +++ src/compiler/tokenizer.py | 88 + src/compiler/type_checker.py | 301 +++ src/compiler/types.py | 29 + test-gadget.py | 31 + tests/__init__.py | 0 tests/assembly_generator_test.py | 135 + tests/interpreter_test.py | 419 ++++ tests/ir_generator_test.py | 12 + tests/parser_test.py | 2505 +++++++++++++++++++ tests/tokenizer_test.py | 28 + tests/type_checker_test.py | 921 +++++++ 41 files changed, 7289 insertions(+) create mode 100644 .dockerignore create mode 100644 .gitignore create mode 100644 .python-version create mode 100644 Dockerfile create mode 100644 README.md create mode 100755 check.sh create mode 100755 compiler.sh create mode 100644 exercises/1.py create mode 100644 exercises/6.txt create mode 100644 exercises/7.txt create mode 100644 exercises/a + b.drawio create mode 100755 exercises/asmprogram create mode 100644 exercises/asmprogram.s create mode 100644 exercises/f(a * f(b)) + c.drawio create mode 100644 
exercises/f(a + b, b + c).drawio create mode 100644 exercises/f(f(a)).drawio create mode 100644 exercises/while i < 100 do i = i + 1.drawio create mode 100644 mypy.ini create mode 100644 poetry.lock create mode 100644 pyproject.toml create mode 100644 src/compiler/__init__.py create mode 100644 src/compiler/__main__.py create mode 100644 src/compiler/assembler.py create mode 100644 src/compiler/assembly_generator.py create mode 100644 src/compiler/ast.py create mode 100644 src/compiler/interpreter.py create mode 100644 src/compiler/intrinsics.py create mode 100644 src/compiler/ir.py create mode 100644 src/compiler/ir_generator.py create mode 100644 src/compiler/parser.py create mode 100644 src/compiler/tokenizer.py create mode 100644 src/compiler/type_checker.py create mode 100644 src/compiler/types.py create mode 100755 test-gadget.py create mode 100644 tests/__init__.py create mode 100644 tests/assembly_generator_test.py create mode 100644 tests/interpreter_test.py create mode 100644 tests/ir_generator_test.py create mode 100644 tests/parser_test.py create mode 100644 tests/tokenizer_test.py create mode 100644 tests/type_checker_test.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..bfa869a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.test-gadget +.*_cache +*.pyc +__pycache__ +.git +exercises +program +program.src diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f962ccb --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +*.pyc +__pycache__ +/.mypy_cache +/.pytest_cache +/.test-gadget +program +program.src diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..450178b --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.7 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8d3a750 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.12-alpine3.21 + +RUN pip install --no-cache-dir poetry==2.0.0 +RUN apk add --no-cache 
bash binutils + +WORKDIR /compiler +COPY pyproject.toml poetry.lock . +RUN poetry install --no-root --no-cache + +COPY . . + +RUN poetry install --no-cache && ./check.sh + +EXPOSE 3000 + +# Setting PYTHONUNBUFFERED forces print() to flush even if stdout is not a TTY. +ENV PYTHONUNBUFFERED=1 +CMD ["./compiler.sh", "serve", "--host=0.0.0.0"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..cdfdef8 --- /dev/null +++ b/README.md @@ -0,0 +1,53 @@ + +## Setup + +Requirements: + +- [Pyenv](https://github.com/pyenv/pyenv) for installing Python 3.12+ + - Recommended installation method: the "automatic installer" + i.e. `curl https://pyenv.run | bash` + - Follow the instructions at the end of the output to make pyenv available in your shell. + You may need to restart your shell or even log out and log in again to make + the `pyenv` command available. +- [Poetry](https://python-poetry.org/) for installing dependencies + - Recommended installation method: the "official installer" + i.e. `curl -sSL https://install.python-poetry.org | python3 -` + +Install dependencies: + + # Install Python specified in `.python-version` + pyenv install + # Install dependencies specified in `pyproject.toml` + poetry install + +If `pyenv install` gives an error about `_tkinter`, you can ignore it. +If you see other errors, you may have to investigate. + +If you have trouble with Poetry not picking up pyenv's python installation, +try `poetry env remove --all` and then `poetry install` again. + +Typecheck and run local unit tests: + + ./check.sh + # or individually: + poetry run mypy . + poetry run pytest -vv + +Once you've finished your compiler, edit `src/compiler/__main__.py` to call your compiler in function `call_compiler`.
+Then you can run your compiler on a source code file like this: + + ./compiler.sh compile path/to/source/code --output=path/to/output/file + +You can send the finished compiler to Test Gadget for evaluation with: + + ./test-gadget.py submit + +See the course page for more information. + +## IDE setup + +Recommended VSCode extensions: + +- Python +- Pylance +- autopep8 diff --git a/check.sh b/check.sh new file mode 100755 index 0000000..b905a2a --- /dev/null +++ b/check.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -euo pipefail +cd "$(dirname "${0}")" +poetry run mypy . +rm -Rf test_programs/workdir +poetry run pytest -vv tests/ diff --git a/compiler.sh b/compiler.sh new file mode 100755 index 0000000..c0a81d1 --- /dev/null +++ b/compiler.sh @@ -0,0 +1,3 @@ +#!/bin/bash +set -euo pipefail +exec poetry -C "$(dirname "${0}")" run main "$@" diff --git a/exercises/1.py b/exercises/1.py new file mode 100644 index 0000000..8c5d461 --- /dev/null +++ b/exercises/1.py @@ -0,0 +1,8 @@ +import re + +re_1 = r'hello!*' +re_2 = r'hello!* +there' +re_3 = r'(\+|\-|\*|\/)' +re_4 = r'\-?[0-9]+' +re_5 = r'\"([a-z]+ *)+\"' +re_6 = r'\"([[a-z]|\\.]+ *)+\"' diff --git a/exercises/6.txt b/exercises/6.txt new file mode 100644 index 0000000..9c7a000 --- /dev/null +++ b/exercises/6.txt @@ -0,0 +1,22 @@ +// a * b +Call(*, [a, b], x1) + +// f(g(x + 1)) +Call(+, [x, 1], x1) +Call(g, [x2], x2) +Call(f, [x3], x3) + +// { f(x); f(y); } +Call(f, [x], x1) +Call(f, [y], x2) + +// while a < b do f() +L0 +Call(<, [a, b], x1) +CondJump(x1, L1, L2) + +L1 +Call(f, [], x2) +Jump(L0) + +L2 diff --git a/exercises/7.txt b/exercises/7.txt new file mode 100644 index 0000000..86eac4d --- /dev/null +++ b/exercises/7.txt @@ -0,0 +1,3 @@ +addq %rax, %rbx +addq %rbx, %rcx +movq %rcx, 48(%rdx) diff --git a/exercises/a + b.drawio b/exercises/a + b.drawio new file mode 100644 index 0000000..9945a18 --- /dev/null +++ b/exercises/a + b.drawio @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git 
a/exercises/asmprogram b/exercises/asmprogram new file mode 100755 index 0000000000000000000000000000000000000000..9ee43c2f4d3cbfd53cebe60ac5b8cb7744d07c06 GIT binary patch literal 16712 zcmeHOU2Ggz6~43UG@Czroe+{Hg-n7YOohikajZabJa%Gd+&GCzTvUph(Rg>(9<+a* zoo($>sjgClOauxdLP3he3q^iNMUZ%?grEt57zz)L5K@2#(uzv18roD8WI)Pt&b{Ys z##39geLzB*Bh8$jd%kn-oVhc5XJ_uSgCj!;O%qHBaZsQ(nl`0L?S-%oNR?~{M58c8 zi}-}N4=fF^nX(1Bl9;EHs)2cm{LqZJ^VEJmY04Jn4@gZ2=G;T%h`UE7iHaXsqg3W1 zt#ri6Cej5RAM+O|0rN0U=F?Rip_@t|{wp#w596>lOKr+wsn?VO^VWp2cMxYBs)$P& zsy%hklmhbw>L+uqC-xDyQL5)?QsJ2E(1xVTJ;W!@xH93&gk#=9e$3GyiE(qfRK*a%21MeE*;3A3poU=}RX@-r4Zl9|y1gFyEf{#(TQk z^Vzn%SFBF8O&#cI>*+QsC8JBm!B`v*#P>y5gUu^=-oJ{yAME-H{+M^AIHHnqiaBB1 z(D#g82^>GL3yxP5W#20XIe}(8f*x^XWO(3^-DPwc`|I+)Za_{HN<~82kR<^%(qx^m z8ye_3a6UiP9o{;u;)7)S@wdJMmrX+SfR1#t&y*amCG0cPt&p9DylOH{$+1p!YNehE z*?COKhIkiOAg(}Mfw%&31>y?C6^JVkSK$A!0@j(^DeLT--?zgf{#UOAiQ0m7<`=1J z^&_e6e+8Oqdw&c2rd=jTs83@18@Fmedw-5PPO;xu0JZb4!?QO|TeGiOXZ~_$bZoHm zO6RYw`T3q^fNq@t*LQP9)2`=be2jN3tR5EPB^@T*Rxj>@Eia)Pc5B_sXsF%Fox7F` z15EKV>HpkYlFi<+uD;c0U0q6A+I8#3-JrQ1VJ$_7;{Vq4@qf+qZiSAC>Xt{W+1D4J zhNHN?CJ7ep@oVqNhyHJV9UZObZqH!%=I77B!CBv^E-k)&w^p;xENRuYd)7OPe_-o2 z>+Il?l^nbSt>+OldwuZ|HqKvxK%m}!{Mz1pR=~Di8t~jKI~!y?C z6^JVkS0JuHT!FX(aRuTEeDDh3do^um@*y3d_&)4`JGI&~z(;_84)`;`@o695St4mq zpA_0uTHCpCZR(u1HjVG&@Ev^${Pv&&(8|8*-5vUF1Ydq9r-R-B+l9AlHRPf(J=D1U z;ie5wrDjCmj)xxH*CNB>yEn{*IDBu4g#ArODc;2uh$|3RAg(}Mfw%&31>y?C6^JVk zS0JvyY8B{!_m7f@ho8sLzEQQpLzgL~O{8BWfBbF~iN9ywsnq&+#XCqB2TdvAa~~3a zcl`IewG!%=lO{6tZ8mVL!9-Pl1C4xv+T{yT&_~e-v_Ubko%mMb%=p{RMf%>8pA+RX zCOGjkDt{Zijew~&rj*!zllsH`W;}d+Dv9&MYg~CebL0;%k)_Ucg-2egU!Zn^|IZkI zKORlkL+5vhc(nbSC2yWy+dj<(LOPmw6Zrvad9t15&q`jIkn|K42- znAQkQI?LKD1K=p@gH4;!nnikBv$nY@*|@WDZ{rsHp87F~P04S5HBI|><0b(Nqre8j zE*UVP?MgHyy1**rPOM9Cj^ui6ckffxqmLc!#9i0XmO}MUx5wd)bBsiD#X`-Yhp|NenMvP0}xg^-XBQ^ z=zpX?KP5YUUlFF&>hKCUX`)$tUtjKz*p>3$l>S>q^m=*~0q_qMuSbw7?B0b<6Wave zhuFRzP9*jzO2c5+;jWgZdUE?G!JY;?*O~1G#z`~^zAv-A8||!w_z{W|jXwr9LC_ONqa 
z(X$=jccyK(82Hm72e%&A&Q=SBX@G<@8$5$$u9?iE!zZ8ca?nv+cVXvBe!&UUS_@We zd+22Uu|a$A_+i@?_Teuc?>{zt2qMFs${w_+)jE7q*hfZA4D^rKCx(Vj4UXAk{R1O| zfQT;TFnQ$$uGVm~PT&Z65VXXj-i!-7peWF;R&#)|Y-tCyRZ8}xQ_SXFVGo}GudG+J zs}(m3af`6a=5622moiS^Vu--Y*kuoV?eR*5V#>88a?MF3@7S?Je8^ZfA&km&A#lcl z2fpHyEEP+EYfKcY2Heiet{+Sbqi!0ndawY;%L)V5ueg3$NXJR1GAWGgbP<9o9{8%| zN%&dtO2uW04Q+5B4gyFj=L3ug9Ww%V3idb*psgh9X}FVg;wQ6!WsSPkKQP>eBU0kk zcu;tjA&pb;GSG(-M24Lq)2Yxv_BcP!FL@pC0-b*zpN!*ipyrwP@Vx%8To$#E8{V`XVHc@_ZuQxFs zpKHRzc;+7^eAowNlLAKLrX!AfSHYJF*c$QYyln1V1%H9?@Fca2qWHd5@RtaGdKp@s ztFY-P;mN(;w7ye`$Lkg)aDL`rSOxzm;rspz9^>%6BkX#C@KZD}_CO=ng2w~@{^9+N f$4(fZ-mkP9+)-+bsPHVaGX7OsP@IbZBC7ZgzH%3U literal 0 HcmV?d00001 diff --git a/exercises/asmprogram.s b/exercises/asmprogram.s new file mode 100644 index 0000000..1f952ac --- /dev/null +++ b/exercises/asmprogram.s @@ -0,0 +1,55 @@ + # Metadata for debuggers and other tools + .global main + .type main, @function + .extern printf + .extern scanf + + .section .text # Begins code and data +# Label that marks beginning of main function +main: + # Function stack setup + pushq %rbp + movq %rsp, %rbp + subq $128, %rsp # Reserve 128 bytes of stack space + + # Read an integer into -8(%rbp) + movq $scan_format, %rdi # 1st param: "%ld" + leaq -8(%rbp), %rsi # 2nd param: (%rbp - 8) + callq scanf + # If the input was invalid, jump to end + cmpq $1, %rax + jne .Lend + + # Read another integer into -16(%rbp) + movq $scan_format, %rdi + leaq -16(%rbp), %rsi + callq scanf + # If the input was invalid, jump to end + cmpq $1, %rax + jne .Lend + + + # Add the two integers together + movq -8(%rbp), %rax # Load first integer into %rax + addq -16(%rbp), %rax # Add second integer to %rax + + # Call function 'printf("%ld\n", %rsi)' + # to print the number in %rsi. + movq $print_format, %rdi + movq %rax, %rsi + callq printf + +# Labels starting with ".L" are local to this function, +# i.e. another function than "main" could have its own ".Lend". 
+.Lend: + # Return from main with status code 0 + movq $0, %rax + movq %rbp, %rsp + popq %rbp + ret + +# String data that we pass to functions 'scanf' and 'printf' +scan_format: + .asciz "%ld" +print_format: + .asciz "%ld\n" diff --git a/exercises/f(a * f(b)) + c.drawio b/exercises/f(a * f(b)) + c.drawio new file mode 100644 index 0000000..644a3e4 --- /dev/null +++ b/exercises/f(a * f(b)) + c.drawio @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/exercises/f(a + b, b + c).drawio b/exercises/f(a + b, b + c).drawio new file mode 100644 index 0000000..7687d11 --- /dev/null +++ b/exercises/f(a + b, b + c).drawio @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/exercises/f(f(a)).drawio b/exercises/f(f(a)).drawio new file mode 100644 index 0000000..d85fd25 --- /dev/null +++ b/exercises/f(f(a)).drawio @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/exercises/while i < 100 do i = i + 1.drawio b/exercises/while i < 100 do i = i + 1.drawio new file mode 100644 index 0000000..14258b3 --- /dev/null +++ b/exercises/while i < 100 do i = i + 1.drawio @@ -0,0 +1,61 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..b112b2c --- /dev/null +++ b/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +disallow_untyped_defs = True +disallow_untyped_calls = True diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..75b67f1 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,173 @@ +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
+ +[[package]] +name = "autopep8" +version = "2.3.1" +description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +optional = false +python-versions = ">=3.8" +files = [ + {file = "autopep8-2.3.1-py2.py3-none-any.whl", hash = "sha256:a203fe0fcad7939987422140ab17a930f684763bf7335bdb6709991dd7ef6c2d"}, + {file = "autopep8-2.3.1.tar.gz", hash = "sha256:8d6c87eba648fdcfc83e29b788910b8643171c395d9c4bcf115ece035b9c9dda"}, +] + +[package.dependencies] +pycodestyle = ">=2.12.0" + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = 
"sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "packaging" +version = "24.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + 
+[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pycodestyle" +version = "2.12.1" +description = "Python style guide checker" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3"}, + {file = "pycodestyle-2.12.1.tar.gz", hash = "sha256:6838eae08bbce4f6accd5d5572075c63626a15ee3e6f842df996bf62f6d73521"}, +] + +[[package]] +name = "pytest" +version = "8.3.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + +[metadata] +lock-version = "2.0" +python-versions = "^3.12" +content-hash = "b3bb39fa5b66f0bdb5a3c762a81c8fc01c30b0103c4ff37a063e4f2e32af0a6f" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bed76a1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,31 @@ +[tool.poetry] +name = 
"compilers-project" +version = "0.0.0" +description = "" +authors = [] +readme = "README.md" +packages = [{include = "compiler", from = "src", format = ["sdist"]}] + +[tool.poetry.dependencies] +python = "^3.12" + +[tool.poetry.group.dev.dependencies] +autopep8 = "^2.3.1" +mypy = "^1.13.0" +pytest = "^8.3.3" + +[tool.poetry.scripts] +main = "compiler.__main__:main" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +pythonpath = "src" +addopts = [ + "--import-mode=importlib", +] + +[virtualenvs] +prefer-active-python = true diff --git a/src/compiler/__init__.py b/src/compiler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/compiler/__main__.py b/src/compiler/__main__.py new file mode 100644 index 0000000..bab6296 --- /dev/null +++ b/src/compiler/__main__.py @@ -0,0 +1,117 @@ +from base64 import b64encode +import json +import re +import sys +from socketserver import ForkingTCPServer, StreamRequestHandler +from traceback import format_exception +from typing import Any +from compiler.tokenizer import tokenize +from compiler.parser import parse +from compiler.type_checker import typecheck_module +from compiler.ir_generator import generate_ir_from_module +from compiler.assembly_generator import generate_assembly_from_dict +from compiler.assembler import assemble_and_get_executable + + +def call_compiler(source_code: str, input_file_name: str) -> bytes: + # *** TODO *** + # Call your compiler here and return the compiled executable. + # Raise an exception on compilation error. + # + # The input file name is informational only: you can optionally include in your source locations and error messages, + # or you can ignore it. 
+ # *** TODO *** + tokens = tokenize(source_code) + ast = parse(tokens) + typecheck_module(ast) + instructions = generate_ir_from_module(ast) + assembly_code = generate_assembly_from_dict(instructions) + print(assembly_code) + return assemble_and_get_executable(assembly_code) + + +def main() -> int: + # === Option parsing === + command: str | None = None + input_file: str | None = None + output_file: str | None = None + host = "127.0.0.1" + port = 3000 + for arg in sys.argv[1:]: + if (m := re.fullmatch(r'--output=(.+)', arg)) is not None: + output_file = m[1] + elif (m := re.fullmatch(r'--host=(.+)', arg)) is not None: + host = m[1] + elif (m := re.fullmatch(r'--port=(.+)', arg)) is not None: + port = int(m[1]) + elif arg.startswith('-'): + raise Exception(f"Unknown argument: {arg}") + elif command is None: + command = arg + elif input_file is None: + input_file = arg + else: + raise Exception("Multiple input files not supported") + + if command is None: + print(f"Error: command argument missing", file=sys.stderr) + return 1 + + def read_source_code() -> str: + if input_file is not None: + with open(input_file) as f: + return f.read() + else: + return sys.stdin.read() + + # === Command implementations === + + if command == 'compile': + source_code = read_source_code() + if output_file is None: + raise Exception("Output file flag --output=... 
required") + executable = call_compiler(source_code, input_file or '(source code)') + with open(output_file, 'wb') as f: + f.write(executable) + elif command == 'serve': + try: + run_server(host, port) + except KeyboardInterrupt: + pass + else: + print(f"Error: unknown command: {command}", file=sys.stderr) + return 1 + return 0 + + +def run_server(host: str, port: int) -> None: + class Server(ForkingTCPServer): + allow_reuse_address = True + request_queue_size = 32 + + class Handler(StreamRequestHandler): + def handle(self) -> None: + result: dict[str, Any] = {} + try: + input_str = self.rfile.read().decode() + input = json.loads(input_str) + if input["command"] == "compile": + source_code = input["code"] + executable = call_compiler(source_code, "(source code)") + result["program"] = b64encode(executable).decode() + elif input["command"] == "ping": + pass + else: + result["error"] = "Unknown command: " + input['command'] + except Exception as e: + result["error"] = "".join(format_exception(e)) + result_str = json.dumps(result) + self.request.sendall(str.encode(result_str)) + + print(f"Starting TCP server at {host}:{port}") + with Server((host, port), Handler) as server: + server.serve_forever() + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/compiler/assembler.py b/src/compiler/assembler.py new file mode 100644 index 0000000..ffd9493 --- /dev/null +++ b/src/compiler/assembler.py @@ -0,0 +1,373 @@ +import subprocess +import tempfile +from contextlib import nullcontext +from os import path +from typing import Any, Callable, ContextManager, TypeVar +import shutil +from pathlib import Path + +T = TypeVar('T') + + +def assemble( + assembly_code: str, + output_file: str, + workdir: str | None = None, + tempfile_basename: str = 'program', + link_with_c: bool = False, + extra_libraries: list[str] = [], +) -> None: + """Invokes 'as' and 'ld' to generate an executable file from Assembly code. + + The file is written to the given path. 
+ """ + _assemble( + assembly_code=assembly_code, + workdir=workdir, + tempfile_basename=tempfile_basename, + link_with_c=link_with_c, + extra_libraries=extra_libraries, + take_output=lambda f: shutil.move(f, output_file) + ) + + +def assemble_and_get_executable( + assembly_code: str, + workdir: str | None = None, + tempfile_basename: str = 'program', + link_with_c: bool = False, + extra_libraries: list[str] = [], +) -> bytes: + """Invokes 'as' and 'ld' to generate an executable file from Assembly code. + + The file is returned. + """ + return _assemble( + assembly_code=assembly_code, + workdir=workdir, + tempfile_basename=tempfile_basename, + link_with_c=link_with_c, + extra_libraries=extra_libraries, + take_output=lambda f: Path(f).read_bytes() + ) + + +def _assemble( + assembly_code: str, + workdir: str | None, + tempfile_basename: str, + link_with_c: bool, + extra_libraries: list[str], + take_output: Callable[[str], T], +) -> T: + if workdir is not None: + wd = Path(workdir).absolute().as_posix() + return _assemble_impl(assembly_code, wd, tempfile_basename, link_with_c, extra_libraries, take_output) + else: + with tempfile.TemporaryDirectory(prefix='compiler_') as wd: + return _assemble_impl(assembly_code, wd, tempfile_basename, link_with_c, extra_libraries, take_output) + + +def _assemble_impl( + assembly_code: str, + workdir: str, + tempfile_basename: str, + link_with_c: bool, + extra_libraries: list[str], + take_output: Callable[[str], T], +) -> T: + stdlib_asm = path.join(workdir, 'stdlib.s') + stdlib_obj = path.join(workdir, 'stdlib.o') + program_asm = path.join(workdir, f'{tempfile_basename}.s') + program_obj = path.join(workdir, f'{tempfile_basename}.o') + output_file = path.join(workdir, 'a.out') + + if link_with_c: + final_stdlib_asm_code = drop_start_symbol(stdlib_asm_code) + else: + final_stdlib_asm_code = stdlib_asm_code + + with open(stdlib_asm, 'w') as f: + f.write(final_stdlib_asm_code) + with open(program_asm, 'w') as f: + f.write(assembly_code) 
+ subprocess.run(['as', '-g', '-o' + + stdlib_obj, stdlib_asm], check=True) + subprocess.run(['as', '-g', '-o' + + program_obj, program_asm], check=True) + linker_flags = ['-static', *[f'-l{lib}' for lib in extra_libraries]] + if link_with_c: + # Linking with the C standard library correctly is complicated, + # as evidenced by the complicated linker command shown by `cc -v something.c`. + # Instead of trying to build the right `ld` command ourselves, we use the C compiler + # to do the linking. + subprocess.run( + ['cc', '-o' + output_file, *linker_flags, stdlib_obj, program_obj], check=True) + else: + subprocess.run( + ['ld', '-o' + output_file, *linker_flags, stdlib_obj, program_obj], check=True) + return take_output(output_file) + + +def drop_start_symbol(code: str) -> str: + return code.split('# BEGIN START')[0] + code.split('# END START')[1] + + +# WARNING: if you want to copy this into a separate file, +# replace all double backslashes `\\` with a single backslash `\`. +stdlib_asm_code: str = """ + .global _start + .global print_int + .global print_bool + .global read_int + .extern main + .section .text + +# BEGIN START (we skip this part when linking with C) +# ***** Function '_start' ***** +# Calls function 'main' and halts the program + +_start: + call main + movq $60, %rax + xorq %rdi, %rdi + syscall +# END START + +# ***** Function 'print_int' ***** +# Prints a 64-bit signed integer followed by a newline. +# +# We'll build up the digits to print on the stack. +# We generate the least significant digit first, +# and the stack grows downward, so that works out nicely. +# +# Algorithm: +# push(newline) +# if x < 0: +# negative = true +# x = -x +# while x > 0: +# push(digit for (x % 10)) +# x = x / 10 +# if negative: +# push(minus sign) +# syscall 'write' with pushed data +# return the original argument +# +# Registers: +# - rdi = our input number, which we divide down as we go +# - rsp = stack pointer, pointing to the next character to emit. 
+# - rbp = pointer to one after the last byte of our output (which grows downward) +# - r9 = whether the number was negative +# - r10 = a copy of the original input, so we can return it +# - rax, rcx and rdx are used by intermediate computations + +print_int: + pushq %rbp # Save previous stack frame pointer + movq %rsp, %rbp # Set stack frame pointer + movq %rdi, %r10 # Back up original input + decq %rsp # Point rsp at first byte of output + # TODO: this non-alignment confuses debuggers. Use a different register? + + # Add newline as the last output byte + movb $10, (%rsp) # ASCII newline = 10 + decq %rsp + + # Check for zero and negative cases + xorq %r9, %r9 + xorq %rax, %rax + cmpq $0, %rdi + je .Ljust_zero + jge .Ldigit_loop + incq %r9 # If < 0, set %r9 to 1 + +.Ldigit_loop: + cmpq $0, %rdi + je .Ldigits_done # Loop done when input = 0 + + # Divide rdi by 10 + movq %rdi, %rax + movq $10, %rcx + cqto + idivq %rcx # Sets rax = quotient and rdx = remainder + + movq %rax, %rdi # The quotient becomes our remaining input + cmpq $0, %rdx # If the remainder is negative (because the input is), negate it + jge .Lnot_negative + negq %rdx +.Lnot_negative: + addq $48, %rdx # ASCII '0' = 48. Add the remainder to get the correct digit. 
+ movb %dl, (%rsp) # Store the digit in the output + decq %rsp + jmp .Ldigit_loop + +.Ljust_zero: + movb $48, (%rsp) # ASCII '0' = 48 + decq %rsp + +.Ldigits_done: + + # Add minus sign if negative + cmpq $0, %r9 + je .Lminus_done + movb $45, (%rsp) # ASCII '-' = 45 + decq %rsp +.Lminus_done: + + # Call syscall 'write' + movq $1, %rax # rax = syscall number for write + movq $1, %rdi # rdi = file handle for stdout + # rsi = pointer to message + movq %rsp, %rsi + incq %rsi + # rdx = number of bytes + movq %rbp, %rdx + subq %rsp, %rdx + decq %rdx + syscall + + # Restore stack registers and return the original input + movq %rbp, %rsp + popq %rbp + movq %r10, %rax + ret + + +# ***** Function 'print_bool' ***** +# Prints either 'true' or 'false', followed by a newline. +print_bool: + pushq %rbp # Save previous stack frame pointer + movq %rsp, %rbp # Set stack frame pointer + movq %rdi, %r10 # Back up original input + + cmpq $0, %rdi # See if the argument is false (i.e. 0) + jne .Ltrue + movq $false_str, %rsi # If so, set %rsi to the address of the string for false + movq $false_str_len, %rdx # and %rdx to the length of that string, + jmp .Lwrite +.Ltrue: + movq $true_str, %rsi # otherwise do the same with the string for true. + movq $true_str_len, %rdx + +.Lwrite: + # Call syscall 'write' + movq $1, %rax # rax = syscall number for write + movq $1, %rdi # rdi = file handle for stdout + # rsi = pointer to message (already set above) + # rdx = number of bytes (already set above) + syscall + + # Restore stack registers and return the original input + movq %rbp, %rsp + popq %rbp + movq %r10, %rax + ret + +true_str: + .ascii "true\\n" +true_str_len = . - true_str +false_str: + .ascii "false\\n" +false_str_len = . - false_str + +# ***** Function 'read_int' ***** +# Reads an integer from stdin, skipping non-digit characters, until a newline. +# +# To avoid the complexity of buffering, it very inefficiently +# makes a syscall to read each byte. 
+# +# It crashes the program if input could not be read. +read_int: + pushq %rbp # Save previous stack frame pointer + movq %rsp, %rbp # Set stack frame pointer + pushq %r12 # Back up r12 since it's callee-saved + pushq $0 # Reserve space for input + # (we only write the lowest byte, + # but loading 64-bits at once is easier) + + xorq %r9, %r9 # Clear r9 - it'll store the minus sign + xorq %r10, %r10 # Clear r10 - it'll accumulate our output + # Skip r11 - syscalls destroy it + xorq %r12, %r12 # Clear r12 - it'll count the number of input bytes read. + + # Loop until a newline or end of input is encountered +.Lloop: + # Call syscall 'read' + xorq %rax, %rax # syscall number for read = 0 + xorq %rdi, %rdi # file handle for stdin = 0 + movq %rsp, %rsi # rsi = pointer to buffer + movq $1, %rdx # rdx = buffer size + syscall # result in rax = number of bytes read, + # or 0 on end of input, -1 on error + + # Check return value: either -1, 0 or 1. + cmpq $0, %rax + jg .Lno_error + je .Lend_of_input + jmp .Lerror + +.Lend_of_input: + cmpq $0, %r12 + je .Lerror # If we've read no input, it's an error. + jmp .Lend # Otherwise complete reading this input. + +.Lno_error: + incq %r12 # Increment input byte counter + movq (%rsp), %r8 # Load input byte to r8 + + # If the input byte is 10 (newline), exit the loop + cmpq $10, %r8 + je .Lend + + # If the input byte is 45 (minus sign), negate r9 + cmpq $45, %r8 + jne .Lnegation_done + xorq $1, %r9 +.Lnegation_done: + + # If the input byte is not between 48 ('0') and 57 ('9') + # then skip it as a junk character. 
class Locals:
    """Maps each local IR variable to its stack slot.

    Slots are 8 bytes each and are laid out below `%rbp` in the order the
    variables are given: the first variable lives at `-8(%rbp)`, the second
    at `-16(%rbp)`, and so on.
    """
    _var_to_location: dict[IRVar, str]
    _stack_used: int

    def __init__(self, variables: list[IRVar]) -> None:
        # Slot numbers start at 1 so that slot k sits at offset -8*k.
        self._var_to_location = {
            ir_var: f'{-8 * slot}(%rbp)'
            for slot, ir_var in enumerate(variables, start=1)
        }
        self._stack_used = len(variables) * 8

    def get_ref(self, v: IRVar) -> str:
        """Returns the Assembly operand (e.g. `-24(%rbp)`) of the
        stack slot that holds variable `v`."""
        return self._var_to_location[v]

    def stack_used(self) -> int:
        """Returns how many bytes of stack the local variables occupy."""
        return self._stack_used
+ emit(f'movabsq ${insn.value}, %rax') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + case Jump(): + emit(f'jmp .L{insn.label.name}') + + case LoadBoolConst(): + if insn.value: + emit(f'movq $1, {locals.get_ref(insn.dest)}') + else: + emit(f'movq $0, {locals.get_ref(insn.dest)}') + + case Copy(): + emit(f'movq {locals.get_ref(insn.source)}, %rax') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + case CondJump(): + emit(f'cmpq $0, {locals.get_ref(insn.cond)}') + emit(f'jne .L{insn.then_label.name}') + emit(f'jmp .L{insn.else_label.name}') + + case Call(): + + if insn.fun.name in list(all_intrinsics.keys()): + all_intrinsics[insn.fun.name](IntrinsicArgs( + arg_refs=[locals.get_ref(arg) for arg in insn.args], + result_register='%rax', + emit=emit + )) + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + else: + + if len(insn.args) == 0: + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 1: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 2: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 3: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 4: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 5: + emit(f'movq {locals.get_ref(insn.args[0])}, 
%rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'movq {locals.get_ref(insn.args[4])}, %r8') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 6: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'movq {locals.get_ref(insn.args[4])}, %r8') + emit(f'movq {locals.get_ref(insn.args[5])}, %r9') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) > 6: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'movq {locals.get_ref(insn.args[4])}, %r8') + emit(f'movq {locals.get_ref(insn.args[5])}, %r9') + + for i in range(len(insn.args) - 1, 5, -1): + emit(f'pushq {locals.get_ref(insn.args[i])}') + + emit(f'callq {insn.fun.name}') + emit(f'addq ${8 * (len(insn.args) - 6)}, %rsp') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + case Return(): + pass + + + + emit(f'movq %rbp, %rsp') + emit(f'popq %rbp') + emit(f'ret') + + return lines + +def generate_assembly_from_dict(instruction_dict: dict[str, list[Instruction]]) -> str: + lines = [] + def emit(line: str) -> None: lines.append(line) + + emit('.extern print_int') + emit('.extern print_bool') + emit('.extern read_int') + + emit('.section .text') + + + for k, v in instruction_dict.items(): + + if k != 'main': + + labels = get_all_labels(v) + variables=get_all_ir_variables(v) + + locals = Locals(variables) + + st = SymTab[str]() + + for label in labels: + label_uuid = str(uuid.uuid4()).replace('-', '') + st.assign(label.name, label_uuid) + + emit(f'.global 
{k}') + emit(f'.type {k}, @function') + + emit(f'{k}:') + + emit('pushq %rbp') + emit('movq %rsp, %rbp') + emit(f'subq ${locals.stack_used()}, %rsp') + + for insn in v: + emit('# ' + str(insn)) + match insn: + + case LoadCustomFuncArgs(): + for i in range(len(insn.args)): + reg = ['%rdi', '%rsi', '%rdx', '%rcx', '%r8', '%r9'][i] + var = insn.args[i] + emit(f'movq {reg}, {locals.get_ref(var)}') + + case Label(): + emit("") + label_uuid = st.lookup(insn.name) + emit(f'.L{label_uuid}:') + + case LoadIntConst(): + if -2**31 <= insn.value < 2**31: + emit(f'movq ${insn.value}, {locals.get_ref(insn.dest)}') + else: + emit(f'movabsq ${insn.value}, %rax') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + case Jump(): + label_uuid = st.lookup(insn.label.name) + emit(f'jmp .L{label_uuid}') + + case LoadBoolConst(): + if insn.value: + emit(f'movq $1, {locals.get_ref(insn.dest)}') + else: + emit(f'movq $0, {locals.get_ref(insn.dest)}') + + case Copy(): + emit(f'movq {locals.get_ref(insn.source)}, %rax') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + case CondJump(): + emit(f'cmpq $0, {locals.get_ref(insn.cond)}') + then_label_uuid = st.lookup(insn.then_label.name) + emit(f'jne .L{then_label_uuid}') + else_label_uuid = st.lookup(insn.else_label.name) + emit(f'jmp .L{else_label_uuid}') + + case Call(): + + if insn.fun.name in list(all_intrinsics.keys()): + all_intrinsics[insn.fun.name](IntrinsicArgs( + arg_refs=[locals.get_ref(arg) for arg in insn.args], + result_register='%rax', + emit=emit + )) + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + else: + + if len(insn.args) == 0: + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 1: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 2: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + 
emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 3: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 4: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 5: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'movq {locals.get_ref(insn.args[4])}, %r8') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) == 6: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'movq {locals.get_ref(insn.args[4])}, %r8') + emit(f'movq {locals.get_ref(insn.args[5])}, %r9') + emit(f'callq {insn.fun.name}') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + elif len(insn.args) > 6: + emit(f'movq {locals.get_ref(insn.args[0])}, %rdi') + emit(f'movq {locals.get_ref(insn.args[1])}, %rsi') + emit(f'movq {locals.get_ref(insn.args[2])}, %rdx') + emit(f'movq {locals.get_ref(insn.args[3])}, %rcx') + emit(f'movq {locals.get_ref(insn.args[4])}, %r8') + emit(f'movq {locals.get_ref(insn.args[5])}, %r9') + + for i in range(len(insn.args) - 1, 5, -1): + emit(f'pushq {locals.get_ref(insn.args[i])}') + + emit(f'callq {insn.fun.name}') + emit(f'addq ${8 * (len(insn.args) - 
6)}, %rsp') + emit(f'movq %rax, {locals.get_ref(insn.dest)}') + + case Return(): + emit(f'movq {locals.get_ref(insn.result)}, %rax') + emit(f'movq %rbp, %rsp') + emit(f'popq %rbp') + emit(f'ret') + + + main_lines: list[str] = generate_assembly(instruction_dict['main']) + lines.extend(main_lines) + + return '\n'.join(lines) diff --git a/src/compiler/ast.py b/src/compiler/ast.py new file mode 100644 index 0000000..73d2957 --- /dev/null +++ b/src/compiler/ast.py @@ -0,0 +1,85 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from compiler.tokenizer import Location +from compiler.types import Type, Unit_Instance + +@dataclass +class Expression: + """Base class for AST nodes representing expressions.""" + location: Location + type: Type = field(kw_only=True, default=Unit_Instance) + +@dataclass +class Literal(Expression): + value: int | bool | None + +@dataclass +class Identifier(Expression): + name: str + +@dataclass +class BinaryOp(Expression): + """AST node for a binary operation like `A + B`""" + left: Expression + op: str + right: Expression + +@dataclass +class UnaryOp(Expression): + op: str + right: Expression + +@dataclass +class If(Expression): + cond_exp: Expression + then_exp: Expression + else_exp: Expression | None + +@dataclass +class While(Expression): + while_exp: Expression + do_exp: Expression + +@dataclass +class FunctionArg(): + name: str + type: Type + +@dataclass +class FunctionDef(): + name: str + args: list[FunctionArg] + return_type: Type + block: Block + +@dataclass +class Module(): + function_defs: list[FunctionDef] + block: Block + +@dataclass +class FunctionCall(Expression): + name: str + args: list[Expression] + +@dataclass +class Block(Expression): + statements: list[Expression] + +@dataclass +class Var(Expression): + name: str + value: Expression + type_f: Type | None + +@dataclass +class Break(Expression): + pass + +@dataclass +class Continue(Expression): + pass + +@dataclass +class Return(Expression): + 
def binary_op_divide(node: ast.BinaryOp, sym_tab: SymTab) -> int:
    """Evaluates `left / right` as integer division that truncates toward zero.

    The quotient is computed with integer arithmetic only. Going through a
    float (`int(a / b)`) silently loses precision once the operands no
    longer fit in a float's 53-bit mantissa.

    Raises:
        ZeroDivisionError: if the right operand evaluates to 0.
    """
    a: Any = interpret(node.left, sym_tab)
    b: Any = interpret(node.right, sym_tab)
    # `//` rounds toward negative infinity, so divide the magnitudes and
    # re-apply the sign to get truncation toward zero.
    magnitude = abs(a) // abs(b)
    return magnitude if (a < 0) == (b < 0) else -magnitude
class SymTab:
    """One scope of the interpreter: a name -> value table plus a link to the enclosing scope.

    A table created without a parent is a top-level table. Its `locals` are
    additionally filled with the operator and built-in function tables under
    the keys `'binary_ops'`, `'unary_ops'` and `'built_in_functions'`.
    """
    locals: dict[str, dict[str, FunctionType] | Any]
    parent: SymTab | None

    def __init__(self, locals: dict | None = None, parent: SymTab | None = None) -> None:
        # Create a fresh dict per instance. A `{}` default in the signature
        # would be one shared object, so bindings made in one table would
        # show up in every other table created without `locals`.
        self.locals = {} if locals is None else locals
        self.parent = parent
        if parent is None:
            self.locals.update({
                'binary_ops': {
                    'and': binary_op_and,
                    'or' : binary_op_or,
                    '+': binary_op_add,
                    '-': binary_op_subtract,
                    '*': binary_op_multiply,
                    '/': binary_op_divide,
                    '%': binary_op_modulo,
                    '<': binary_op_lt,
                    '>': binary_op_gt,
                    '>=': binary_op_gte,
                    '<=': binary_op_lte,
                    '!=': binary_op_diff,
                    '==': binary_op_eq,
                },
                'unary_ops': {
                    '-': unary_op_negate_int,
                    'not': unary_op_negate_bool
                },
                'built_in_functions': {
                    'print_int': built_in_function_print_int,
                    'print_bool': built_in_function_print_bool,
                    'read_int': built_in_function_read_int
                }
            })

    def get_top_level(self) -> SymTab:
        """Returns the outermost table, i.e. the one without a parent."""
        if self.parent is None:
            return self
        return self.parent.get_top_level()

    def assign(self, var_name: str, var_value: Any) -> None:
        """Binds `var_name` in this scope, overwriting any existing binding here."""
        self.locals[var_name] = var_value

    def assign_recursive(self, var_name: str, var_value: Any) -> None:
        """Rebinds `var_name` in the innermost scope that already holds a non-None value for it.

        Raises:
            Exception: if no scope up to the top level has such a binding.
        """
        if self.locals.get(var_name) is not None:
            self.locals[var_name] = var_value
        elif self.parent is not None:
            self.parent.assign_recursive(var_name, var_value)
        else:
            raise Exception(f'cannot assign to undefined name: {var_name}')

    def lookup(self, var_name: str) -> Any:
        """Returns the value bound to `var_name` in this scope only, or None if unbound."""
        return self.locals.get(var_name)

    def lookup_recursive(self, var_name: str) -> Any:
        """Returns the non-None value bound to `var_name` in the innermost scope that has one.

        Raises:
            Exception: if no scope up to the top level has such a binding.
        """
        value = self.locals.get(var_name)
        if value is not None:
            return value
        if self.parent is not None:
            return self.parent.lookup_recursive(var_name)
        raise Exception(f'undefined name: {var_name}')
to_be_executed_function = sym_tab.get_top_level().locals['binary_ops'].get(node.op) + assert isinstance(to_be_executed_function, FunctionType) + return to_be_executed_function(node, sym_tab) + else: + raise Exception() + + case ast.If(): + + e1 = node.cond_exp + e2 = node.then_exp + e3 = node.else_exp + + # if-then + if e3 is None: + e1_value = interpret(e1, sym_tab) + if e1_value == True: + interpret(e2, sym_tab) + return Unit() + elif e1_value == False: + return Unit() + else: + raise Exception() + + # if-then-else + else: + e1_value = interpret(e1, sym_tab) + if e1_value == True: + return interpret(e2, sym_tab) + elif e1_value == False: + return interpret(e3, sym_tab) + else: + raise Exception() + + case ast.While(): + + while interpret(node.while_exp, sym_tab): + interpret(node.do_exp, sym_tab) + + if not interpret(node.while_exp, sym_tab): + return Unit() + else: + raise Exception() + + case ast.FunctionCall(): + if node.name == 'print_int': + if len(node.args) != 1: + raise Exception() + + arg_value = interpret(node.args[0], sym_tab) + sym_tab.get_top_level().locals['built_in_functions']['print_int'](arg_value) + return Unit() + + elif node.name == 'print_bool': + if len(node.args) != 1: + raise Exception() + + arg_value = interpret(node.args[0], sym_tab) + sym_tab.get_top_level().locals['built_in_functions']['print_bool'](arg_value) + return Unit() + + elif node.name == 'read_int': + return sym_tab.get_top_level().locals['built_in_functions']['read_int']() + + else: + # User-defined functions + interpreted_args = [] + for arg in node.args: + interpreted_args.append(interpret(arg, sym_tab)) + return sym_tab.lookup_recursive(node.name)(*interpreted_args) + + case ast.Block(): + + if node.statements == []: + return Unit() + + new_sym_tab = SymTab(locals={}, parent=sym_tab) + for index, statement in enumerate(node.statements): + statement_value = interpret(statement, new_sym_tab) + if index == len(node.statements) - 1: + if statement_value != 
def _intrinsic(name: str) -> Callable[[Intrinsic], Intrinsic]:
    """Function decorator that registers the decorated function as an intrinsic.

    The function is stored in `all_intrinsics` under `name` and returned
    unchanged.

    Raises:
        ValueError: if an intrinsic is already registered under `name`.
    """
    def register(f: Intrinsic) -> Intrinsic:
        # An explicit check rather than `assert`: asserts are removed when
        # Python runs with -O, and a duplicate name would then silently
        # replace the earlier intrinsic.
        if name in all_intrinsics:
            raise ValueError(f'intrinsic {name!r} is already registered')
        all_intrinsics[name] = f
        return f
    return register
@dataclass(frozen=True)
class Instruction():
    """Common base of all IR instructions; carries the source location."""
    location: Location

    def __str__(self) -> str:
        """Renders the instruction in the style of our IR code examples,
        e.g. 'LoadIntConst(3, x1)'. The `location` field is left out."""
        def render(value: Any) -> str:
            # Lists are rendered element by element, recursively.
            if not isinstance(value, list):
                return str(value)
            return '[' + ', '.join(map(render, value)) + ']'

        shown = [
            render(getattr(self, f.name))
            for f in fields(self)
            if f.name != 'location'
        ]
        return type(self).__name__ + '(' + ', '.join(shown) + ')'
Type, Unit_Instance, FunType, Any_Instance +from compiler.tokenizer import L, Location +import uuid + +def print_instructions(ls: list[Instruction]) -> str: + output_str = '' + for instruction in ls: + output_str += f'{instruction.__str__()}\n' + return output_str + +class SymTab[T]: + locals: dict[str, T] + parent: SymTab | None + + def __init__(self, locals: dict = {}, parent: SymTab | None = None) -> None: + self.locals = locals + self.parent = parent + + def get_top_level(self) -> SymTab: + if self.parent is None: + return self + return self.parent.get_top_level() + + def assign(self, var_name: str, var_value: T) -> None: + self.locals[var_name] = var_value + + def assign_recursive(self, var_name: str, var_value: T) -> None: + if self.locals.get(var_name) is not None: + self.locals[var_name] = var_value + else: + if self.parent is not None: + self.parent.assign_recursive(var_name, var_value) + else: + raise Exception() + + def lookup(self, var_name: str) -> T: + lookup_result = self.locals.get(var_name) + if lookup_result is not None: + return lookup_result + else: + raise Exception() + + def lookup_recursive(self, var_name: str) -> T: + lookup_result = self.locals.get(var_name) + if lookup_result is not None: + return lookup_result + else: + if self.parent is not None: + return self.parent.lookup_recursive(var_name) + else: + raise Exception(var_name) + +binary_op_common_func_type = FunType(params=[Int_Instance, Int_Instance], result=Int_Instance) +binary_op_int_comparison_func_type = FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance) +binary_op_and_or_func_type = FunType(params=[Bool_Instance, Bool_Instance], result=Bool_Instance) +binary_op_eq_func_type = FunType(params=[Any_Instance, Any_Instance], result=Bool_Instance) + +unary_op_int_func_type = FunType(params=[Int_Instance], result=Int_Instance) +unary_op_bool_func_type = FunType(params=[Bool_Instance], result=Bool_Instance) + +print_int_func_type = FunType(params=[Int_Instance], 
result=Unit_Instance) +print_bool_func_type = FunType(params=[Bool_Instance], result=Unit_Instance) +read_int_func_type = FunType(params=[], result=Int_Instance) + +root_types: dict[IRVar, Type] = { + IRVar('and'): binary_op_and_or_func_type, + IRVar('or') : binary_op_and_or_func_type, + IRVar('+'): binary_op_common_func_type, + IRVar('-'): binary_op_common_func_type, + IRVar('*'): binary_op_common_func_type, + IRVar('/'): binary_op_common_func_type, + IRVar('%'): binary_op_common_func_type, + IRVar('<'): binary_op_int_comparison_func_type, + IRVar('>'): binary_op_int_comparison_func_type, + IRVar('>='): binary_op_int_comparison_func_type, + IRVar('<='): binary_op_int_comparison_func_type, + IRVar('!='): binary_op_eq_func_type, + IRVar('=='): binary_op_eq_func_type, + + IRVar('unary_-'): unary_op_int_func_type, + IRVar('unary_not'): unary_op_bool_func_type, + + IRVar('print_int'): print_int_func_type, + IRVar('print_bool'): print_bool_func_type, + IRVar('read_int'): read_int_func_type, +} + +def generate_ir( + # 'root_types' parameter should map all global names + # like 'print_int' and '+' to their types. + root_symtab: SymTab[IRVar], + root_types: dict[IRVar, Type], + root_expr: ast.Expression, + is_func_def: bool = False, + custom_func_ir_vars: list[IRVar] = [] +) -> list[Instruction]: + + var_types: dict[IRVar, Type] = root_types.copy() + + # 'var_unit' is used when an expression's type is 'Unit'. + var_unit = IRVar('unit') + var_types[var_unit] = Unit_Instance + + def new_var(t: Type) -> IRVar: + # Create a new unique IR variable and + # add it to var_types + new_ir_var = IRVar(str(uuid.uuid4()).replace('-', '')) + var_types[new_ir_var] = t + return new_ir_var + + def new_label(location: Location = L) -> Label: + return Label(location, str(uuid.uuid4()).replace('-', '')) + + # We collect the IR instructions that we generate + # into this list. 
+ ins: list[Instruction] = [] + + # This function visits an AST node, + # appends IR instructions to 'ins', + # and returns the IR variable where + # the emitted IR instructions put the result. + # + # It uses a symbol table to map local variables + # (which may be shadowed) to unique IR variables. + # The symbol table will be updated in the same way as + # in the interpreter and type checker. + def visit(st: SymTab[IRVar], lst: SymTab[Label], expr: ast.Expression) -> IRVar: + loc = expr.location + + match expr: + case ast.Literal(): + # Create an IR variable to hold the value, + # and emit the correct instruction to + # load the constant value. + match expr.value: + case bool(): + var = new_var(Bool_Instance) + ins.append(LoadBoolConst(loc, expr.value, var)) + case int(): + var = new_var(Int_Instance) + ins.append(LoadIntConst(loc, expr.value, var)) + case None: + var = var_unit + case _: + raise Exception(f'{loc}: unsupported literal: {type(expr.value)}') + + # Return the variable that holds + # the loaded value. + return var + + case ast.Identifier(): + # Look up the IR variable that corresponds to + # the source code variable. 
+ return st.lookup_recursive(expr.name) + + case ast.BinaryOp(): + + if expr.op == '=': + if not isinstance(expr.left, ast.Identifier): + raise Exception() + ir_var_left = visit(st, lst, expr.left) + ir_var_right = visit(st, lst, expr.right) + ins.append(Copy(loc, ir_var_right, ir_var_left)) + return ir_var_right + + elif expr.op == 'and': + v_result = new_var(Bool_Instance) + + v_left = visit(st, lst, expr.left) + + label_1 = new_label() + label_2 = new_label() + label_3 = new_label() + + ins.append(CondJump(loc, v_left, label_1, label_2)) + + ins.append(label_1) + v_right = visit(st, lst, expr.right) + ins.append(Copy(loc, v_right, v_result)) + ins.append(Jump(loc, label_3)) + + ins.append(label_2) + ins.append(Copy(loc, v_left, v_result)) + + ins.append(label_3) + + return v_result + + elif expr.op == 'or': + v_result = new_var(Bool_Instance) + + v_left = visit(st, lst, expr.left) + + label_1 = new_label() + label_2 = new_label() + label_3 = new_label() + + ins.append(CondJump(loc, v_left, label_1, label_2)) + + ins.append(label_1) + ins.append(Copy(loc, v_left, v_result)) + ins.append(Jump(loc, label_3)) + + ins.append(label_2) + v_right = visit(st, lst, expr.right) + ins.append(Copy(loc, v_right, v_result)) + + ins.append(label_3) + + return v_result + + else: + # Ask the symbol table to return the variable that refers + # to the operator to call. + var_op = st.get_top_level().lookup(expr.op) + # Recursively emit instructions to calculate the operands. + var_left = visit(st, lst, expr.left) + var_right = visit(st, lst, expr.right) + # Generate variable to hold the result. + var_result = new_var(expr.type) + # Emit a Call instruction that writes to that variable. 
+ ins.append(Call(loc, var_op, [var_left, var_right], var_result)) + return var_result + + case ast.UnaryOp(): + var_op = st.get_top_level().lookup(f'unary_{expr.op}') + var_right = visit(st, lst, expr.right) + var_result = new_var(expr.type) + ins.append(Call(loc, var_op, [var_right], var_result)) + return var_result + + case ast.If(): + if expr.else_exp is None: + # Create (but don't emit) some jump targets. + l_then = new_label(expr.then_exp.location) + l_end = new_label() + + # Recursively emit instructions for + # evaluating the condition. + var_cond = visit(st, lst, expr.cond_exp) + # Emit a conditional jump instruction + # to jump to 'l_then' or 'l_end', + # depending on the content of 'var_cond'. + ins.append(CondJump(loc, var_cond, l_then, l_end)) + + # Emit the label that marks the beginning of + # the "then" branch. + ins.append(l_then) + # Recursively emit instructions for the "then" branch. + visit(st, lst, expr.then_exp) + + # Emit the label that we jump to + # when we don't want to go to the "then" branch. + ins.append(l_end) + + # An if-then expression doesn't return anything, so we + # return a special variable "unit". 
+ return var_unit + + else: + then_label = new_label(expr.then_exp.location) + else_label = new_label(expr.else_exp.location) + end_label = new_label() + + result_ir_var = new_var(expr.type) + + ir_var_cond = visit(st, lst, expr.cond_exp) + ins.append(CondJump(loc, ir_var_cond, then_label, else_label)) + + ins.append(then_label) + then_ir_var = visit(st, lst, expr.then_exp) + ins.append(Copy(loc, then_ir_var, result_ir_var)) + ins.append(Jump(loc, end_label)) + + ins.append(else_label) + else_ir_var = visit(st, lst, expr.else_exp) + ins.append(Copy(loc, else_ir_var, result_ir_var)) + + ins.append(end_label) + + return result_ir_var + + case ast.While(): + l0 = new_label(expr.while_exp.location) + l1 = new_label(expr.do_exp.location) + l2 = new_label() + new_lst: SymTab[Label] = SymTab(locals={}, parent=lst) + new_lst.assign('start_label', l0) + new_lst.assign('end_label', l2) + + ins.append(l0) + x1 = visit(st, new_lst, expr.while_exp) + ins.append(CondJump(loc, x1, l1, l2)) + + ins.append(l1) + visit(st, new_lst, expr.do_exp) + ins.append(Jump(loc, l0)) + + ins.append(l2) + + return var_unit + + case ast.FunctionCall(): + arg_ir_vars = [] + for arg in expr.args: + arg_ir_var = visit(st, lst, arg) + arg_ir_vars.append(arg_ir_var) + ft = st.lookup_recursive(expr.name) + riv = new_var(expr.type) + ins.append(Call(loc, ft, [*arg_ir_vars], riv)) + return riv + + case ast.Block(): + last_expr = expr.statements[len(expr.statements) - 1] + new_st: SymTab = SymTab(locals={}, parent=st) + if isinstance(last_expr, ast.Literal) and last_expr.value is None: + for index, statement in enumerate(expr.statements): + if index == len(expr.statements) - 1: + return var_unit + else: + visit(new_st, lst, statement) + else: + for index, statement in enumerate(expr.statements): + if index == len(expr.statements) - 1: + last_statement_ir_var = visit(new_st, lst, statement) + return last_statement_ir_var + else: + visit(new_st, lst, statement) + + case ast.Var(): + if 
st.locals.get(expr.name) is not None: + raise Exception() + + ir_var_new = new_var(expr.type) + ir_var_for_new_var_value = visit(st, lst, expr.value) + ins.append(Copy(loc, ir_var_for_new_var_value, ir_var_new)) + st.assign(expr.name, ir_var_new) + return var_unit + + case ast.Break(): + end_label = lst.lookup('end_label') + ins.append(Jump(loc, end_label)) + return var_unit + + case ast.Continue(): + start_label = lst.lookup('start_label') + ins.append(Jump(loc, start_label)) + return var_unit + + case ast.Return(): + rs_ir_var = visit(st, lst, expr.result) + ins.append(Return(loc, rs_ir_var)) + return var_unit + + return var_unit # This line is only added for the linter; it should never be actually reached. + + # # Convert 'root_types' into a SymTab + # # that maps all available global names to + # # IR variables of the same name. + # # In the Assembly generator stage, we will give + # # definitions for these globals. For now, + # # they just need to exist. + # root_symtab = SymTab[IRVar]() + # for v in root_types.keys(): + # root_symtab.assign(v.name, v) + + loop_symtab = SymTab[Label]() + + if is_func_def: + ins.append(LoadCustomFuncArgs(L, custom_func_ir_vars)) + + # Start visiting the AST from the root. 
+ var_final_result = visit(root_symtab, loop_symtab, root_expr) + + if is_func_def: + ins.append(Return(root_expr.location, var_final_result)) + else: + if var_types[var_final_result] == Int_Instance: + print_int_ir_var = root_symtab.lookup('print_int') + ins.append(Call(root_expr.location, print_int_ir_var, [var_final_result], new_var(Int_Instance))) + elif var_types[var_final_result] == Bool_Instance: + print_bool_ir_var = root_symtab.lookup('print_bool') + ins.append(Call(root_expr.location, print_bool_ir_var, [var_final_result], new_var(Bool_Instance))) + + return ins + +def generate_ir_from_module(module: ast.Module) -> dict[str, list[Instruction]]: + result: dict[str, list[Instruction]] = {} + + extended_types_for_main = root_types.copy() + + for function_def in module.function_defs: + fdiv = IRVar(function_def.name) + extended_types_for_main.update({ fdiv: function_def.return_type }) + + for function_def in module.function_defs: + + extended_types = extended_types_for_main.copy() + custom_func_ir_vars: list[IRVar] = [] + + for arg in function_def.args: + irv = IRVar(arg.name) + custom_func_ir_vars.append(irv) + extended_types.update({ irv: arg.type }) + + fdiv = IRVar(function_def.name) + extended_types.update({ fdiv: function_def.return_type}) + + root_symtab = SymTab[IRVar]() + for v in extended_types.keys(): + root_symtab.assign(v.name, v) + + result[function_def.name] = generate_ir(root_symtab, extended_types, function_def.block, True, custom_func_ir_vars) + + + root_symtab = SymTab[IRVar]() + for v in extended_types_for_main.keys(): + root_symtab.assign(v.name, v) + result['main'] = generate_ir(root_symtab, extended_types_for_main, module.block) + + return result diff --git a/src/compiler/parser.py b/src/compiler/parser.py new file mode 100644 index 0000000..f4c2fba --- /dev/null +++ b/src/compiler/parser.py @@ -0,0 +1,337 @@ +from compiler.tokenizer import Token +import compiler.ast as ast +from compiler.types import FunType, ParamsType +from 
compiler.type_checker import Unit_Instance, Int_Instance, Bool_Instance, Type + +def parse(tokens: list[Token]) -> ast.Module: + + if len(tokens) == 0: + raise Exception('len(tokens) == 0') + + # This keeps track of which token we're looking at. + pos = 0 + + end_token = Token( + location=tokens[-1].location, + type='end', + text='', + ) + + # 'peek()' returns the token at 'pos', + # or a special 'end' token if we're past the end + # of the token list. + # This way we don't have to worry about going past + # the end elsewhere. + def peek() -> Token: + nonlocal pos, end_token + if pos < len(tokens): + return tokens[pos] + else: + return end_token + + last_consumed_token: Token | None = None + + # 'consume(expected)' returns the token at 'pos' + # and moves 'pos' forward. + # + # If the optional parameter 'expected' is given, + # it checks that the token being consumed has that text. + # If 'expected' is a list, then the token must have + # one of the texts in the list. + def consume(expected: str | list[str] | None = None) -> Token: + nonlocal pos, last_consumed_token + token = peek() + if isinstance(expected, str) and token.text != expected: + raise Exception(f'{token.location}: expected "{expected}", got "{token.text}"') + if isinstance(expected, list) and token.text not in expected: + comma_separated = ', '.join([f'"{e}"' for e in expected]) + raise Exception(f'{token.location}: expected one of: {comma_separated}') + pos += 1 + last_consumed_token = token + return token + + + binary_operators = [ + (['='], 'right'), + (['or'], 'left'), + (['and'], 'left'), + (['==', '!='], 'left'), + (['<', '<=', '>', '>='], 'left'), + (['+', '-'], 'left'), + (['*', '/', '%'], 'left'), + ] + + + def parse_expression(binary_precedence_level: int = 0, allow_var: bool = False) -> ast.Expression: + nonlocal pos + + if binary_precedence_level == len(binary_operators): + if peek().text == 'not': + token = consume('not') + return ast.UnaryOp(location=token.location, op='not', 
right=parse_expression(binary_precedence_level)) + elif peek().text == '-': + token = consume('-') + return ast.UnaryOp(location=token.location, op='-', right=parse_expression(binary_precedence_level)) + else: + return parse_factor(allow_var=allow_var) + + else: + left = parse_expression(binary_precedence_level + 1, allow_var) + precedence_level_array, associativity = binary_operators[binary_precedence_level] + while peek().text in precedence_level_array: + operator_token = consume() + operator = operator_token.text + next_level = 1 if associativity == 'left' else 0 + right = parse_expression(binary_precedence_level + next_level, False) + left = ast.BinaryOp( + location=left.location, + left=left, + op=operator, + right=right + ) + return left + + + def parse_factor(allow_var: bool = False) -> ast.Expression: + if peek().text == 'if': + return parse_if() + elif peek().text == 'while': + return parse_while() + elif peek().text == 'var': + if allow_var: + return parse_var() + else: + raise Exception('"var" is only allowed directly inside blocks {} and in top-level expressions') + elif peek().text == '(': + return parse_parenthesized() + elif peek().text == '{': + return parse_block() + elif peek().type == 'int_literal': + return parse_int_literal() + elif peek().type == 'bool_literal': + return parse_boolean_literal() + elif peek().type == 'identifier': + return parse_identifier() + elif peek().text == 'break': + token = consume('break') + return ast.Break(location=token.location, type=Unit_Instance) + elif peek().text == 'continue': + token = consume('continue') + return ast.Continue(location=token.location, type=Unit_Instance) + elif peek().text == 'return': + token = consume('return') + result = parse_expression() + return ast.Return(location=token.location, result=result) + else: + raise Exception(f'{peek().location}: expected "(", an integer literal or an identifier') + + + def parse_if() -> ast.If: + token = consume('if') + if_exp = parse_expression() + 
then_exp = None + else_exp = None + if peek().text == 'then': + consume('then') + then_exp = parse_expression() + else: + raise Exception(f'{peek().location}: "if" not followed by "then"') + + if peek().text == 'else': + consume('else') + else_exp = parse_expression() + + return ast.If(location=token.location, cond_exp=if_exp, then_exp=then_exp, else_exp=else_exp) + + + def parse_while() -> ast.While: + token = consume('while') + while_exp = parse_expression() + do_exp = None + if peek().text == 'do': + consume('do') + do_exp = parse_expression() + else: + raise Exception(f'{peek().location}: "while" not followed by "do"') + + return ast.While(location=token.location, while_exp=while_exp, do_exp=do_exp) + + + def parse_parenthesized() -> ast.Expression: + consume('(') + # Recursively call the top level parsing function + # to parse whatever is inside the parentheses. + expr = parse_expression() + if peek().text != ')': + raise Exception(f'{peek().location}: "(" not followed by ")"') + consume(')') + return expr + + + def parse_var() -> ast.Var: + token = consume('var') + var_identifier = parse_identifier() + var_type = None + if peek().text == ':': + consume(':') + var_type = parse_type() + consume('=') + var_value = parse_expression() + return ast.Var(location=token.location, name=var_identifier.name, value=var_value, type_f=var_type) + + + def parse_type() -> Type: + next_token = peek() + if next_token.text == 'Unit': + consume('Unit') + return Unit_Instance + elif next_token.text == 'Int': + consume('Int') + return Int_Instance + elif next_token.text == 'Bool': + consume('Bool') + return Bool_Instance + elif next_token.text == '(': + return parse_function_type() + else: + raise Exception() + + + def parse_function_type() -> FunType: + consume('(') + params: list[Type | list['ParamsType']] = [] + if peek().text != ')': + params.append(parse_type()) + while peek().text == ',': + consume(',') + params.append(parse_type()) + consume(')') + consume('=>') + result = 
parse_type() + return FunType(params=params, result=result) + + def parse_function_def() -> ast.FunctionDef: + consume('fun') + function_name = consume().text + consume('(') + function_args: list[ast.FunctionArg] = [] + while peek().text != ')': + arg_name = consume() + consume(':') + arg_type = parse_type() + function_args.append(ast.FunctionArg(arg_name.text, arg_type)) + if peek().text == ',': + consume(',') + consume(')') + consume(':') + return_type = parse_type() + block = parse_block() + return ast.FunctionDef(function_name, function_args, return_type, block) + + def parse_block() -> ast.Block: + token = consume('{') + + # Empty block + if peek().text == '}': + consume('}') + return ast.Block(location=token.location, statements=[]) + + statements = [parse_expression(allow_var=True)] + last_token = tokens[pos - 1] + + while True: + current_peek = peek().text + if current_peek == ';': + consume(';') + if peek().text == '}': + temp_token = consume('}') + statements.append(ast.Literal(location=temp_token.location, value=None)) + return ast.Block(location=token.location, statements=statements) + else: + next_exp = parse_expression(allow_var=True) + statements.append(next_exp) + last_token = tokens[pos - 1] + elif current_peek == '{': + if last_token.text == ';' or last_token.text == '}': + block_exp = parse_block() + statements.append(block_exp) + last_token = tokens[pos - 1] + else: + raise Exception() + elif current_peek == '}': + break + else: + if last_token.text == '}': + next_exp = parse_expression(allow_var=True) + statements.append(next_exp) + last_token = tokens[pos - 1] + else: + raise Exception(f'{peek().location}: unexpected token "{current_peek}" after statement ending with "{last_token.text}"') + + consume('}') + return ast.Block(location=token.location, statements=statements) + + + # This is the parsing function for integer literals. 
+ # It checks that we're looking at an integer literal token, + # moves past it, and returns a 'Literal' AST node + # containing the integer from the token. + def parse_int_literal() -> ast.Literal: + if peek().type != 'int_literal': + raise Exception(f'{peek().location}: expected an integer literal') + token = consume() + return ast.Literal(location=token.location, value=int(token.text)) + + def parse_boolean_literal() -> ast.Literal: + if peek().type != 'bool_literal': + raise Exception(f'{peek().location}: expected an boolean literal') + token = consume() + return ast.Literal(location=token.location, value=True if token.text == 'true' else False) + + + def parse_identifier() -> ast.Identifier | ast.FunctionCall: + if peek().type != 'identifier': + raise Exception(f'{peek().location}: expected an identifier') + token = consume() + + # Function call detected + if peek().text == '(': + consume('(') + args: list[ast.Expression] = [] + if peek().text != ')': # Check if there are arguments + args.append(parse_expression()) + while peek().text == ',': + consume(',') + args.append(parse_expression()) + consume(')') + return ast.FunctionCall(location=token.location, name=str(token.text), args=args) + + return ast.Identifier(location=token.location, name=str(token.text)) + + + statements: list[ast.Expression] = [] + function_defs: list[ast.FunctionDef] = [] + location = tokens[0].location + + while peek().text == 'fun': + function_defs.append(parse_function_def()) + + while peek().type != 'end': + exp = parse_expression(allow_var=True) + statements.append(exp) + if peek().text == ';': + consume(';') + if peek().type == 'end': + statements.append(ast.Literal(location=peek().location, value=None)) + else: + if peek().type != 'end': + if last_consumed_token is not None and last_consumed_token.text == '}': + pass + else: + raise Exception(f'{peek().location}: expected ";" or "{{", got "{peek().text}"') + + if pos < len(tokens): + raise Exception(f'{peek().location}: 
unexpected tokens remaining') + + return ast.Module(function_defs, ast.Block(location=location, statements=statements)) + diff --git a/src/compiler/tokenizer.py b/src/compiler/tokenizer.py new file mode 100644 index 0000000..0b3ba3d --- /dev/null +++ b/src/compiler/tokenizer.py @@ -0,0 +1,88 @@ +import re +from dataclasses import dataclass +from typing import Any + + +@dataclass +class SourceLocation: + file: str + line: int + column: int + + def __eq__(self, value: Any) -> bool: + if value == L: + return True + elif isinstance(value, SourceLocation) and self.file == value.file and self.line == value.line and self.column == value.column: + return True + else: + return False +@dataclass +class SpecialSourceLocation: + def __eq__(self, _: Any) -> bool: + return True + +L = SpecialSourceLocation() + +type Location = SourceLocation | SpecialSourceLocation + +@dataclass +class Token: + location: Location + type: str + text: str + +def tokenize(source_code: str) -> list[Token]: + whitespace_re = r'\s' + comment_re = r'(?:\/\/|#)[^\n]*(?=\n|$)' + + command_re = r'continue|break|fun|return' + type_re = r'Int|Bool|Unit' + integer_re = r'[0-9]+' + boolean_re = r'true|false' + punctuation_re = r'=>|\(|\)|{|}|,|;|:' + operator_re = r'and|not|or|>=|<=|!=|==|\+|-|\*|/|%|=|<|>' + identifier_re = r'[a-zA-Z_][a-zA-Z0-9_]*' + + token_re_list = [ + (command_re, 'command'), + (type_re, 'type'), + (integer_re, 'int_literal'), + (boolean_re, 'bool_literal'), + (punctuation_re, 'punctuation'), + (operator_re, 'operator'), + (identifier_re, 'identifier') + ] + + tokens = [] + source_code_lines = source_code.splitlines() + + for row_index, line in enumerate(source_code_lines): + col_index = 0 + while col_index < len(line): + + whitespace_match = re.match(whitespace_re, line[col_index:]) + comment_match = re.match(comment_re, line[col_index:]) + + if whitespace_match: + whitespace = whitespace_match.group(0) + col_index += len(whitespace) + continue + elif comment_match: + break + + 
curr_col_index = col_index + + for tup in token_re_list: + rex = tup[0] + rex_type = tup[1] + match = re.match(rex, line[col_index:]) + if match: + token = match.group(0) # This is the token that matched. + tokens.append(Token(SourceLocation('', row_index, col_index), rex_type, token)) + col_index += len(token) + break + + if curr_col_index == col_index: + raise Exception() + + return tokens diff --git a/src/compiler/type_checker.py b/src/compiler/type_checker.py new file mode 100644 index 0000000..565b658 --- /dev/null +++ b/src/compiler/type_checker.py @@ -0,0 +1,301 @@ +from __future__ import annotations +import compiler.ast as ast +from compiler.types import Int, Bool, Unit, Type, FunType, Int_Instance, Bool_Instance, Unit_Instance +from compiler.tokenizer import L + +class TypeSymTab: + locals: dict[str, Type | dict[str, FunType]] + parent: TypeSymTab | None + + def __init__(self, locals: dict[str, Type | dict[str, FunType] ] = {}, parent: TypeSymTab | None = None) -> None: + self.locals = locals + self.parent = parent + + binary_op_common_func_type = FunType(params=[Int_Instance, Int_Instance], result=Int_Instance) + binary_op_int_comparison_func_type = FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance) + binary_op_and_or_func_type = FunType(params=[Bool_Instance, Bool_Instance], result=Bool_Instance) + binary_op_eq_func_type = FunType(params=[[Int_Instance, Bool_Instance], [Int_Instance, Bool_Instance]], result=Bool_Instance) + + unary_op_int_func_type = FunType(params=[Int_Instance], result=Int_Instance) + unary_op_bool_func_type = FunType(params=[Bool_Instance], result=Bool_Instance) + + print_int_func_type = FunType(params=[Int_Instance], result=Unit_Instance) + print_bool_func_type = FunType(params=[Bool_Instance], result=Unit_Instance) + read_int_func_type = FunType(params=[], result=Int_Instance) + + if parent is None: + self.locals.update({ + 'binary_ops': { + 'and': binary_op_and_or_func_type, + 'or' : binary_op_and_or_func_type, + 
'+': binary_op_common_func_type, + '-': binary_op_common_func_type, + '*': binary_op_common_func_type, + '/': binary_op_common_func_type, + '%': binary_op_common_func_type, + '<': binary_op_int_comparison_func_type, + '>': binary_op_int_comparison_func_type, + '>=': binary_op_int_comparison_func_type, + '<=': binary_op_int_comparison_func_type, + '!=': binary_op_eq_func_type, + '==': binary_op_eq_func_type, + }, + 'unary_ops': { + '-': unary_op_int_func_type, + 'not': unary_op_bool_func_type + }, + 'built_in_functions': { + 'print_int': print_int_func_type, + 'print_bool': print_bool_func_type, + 'read_int': read_int_func_type + } + }) + + def get_top_level(self) -> TypeSymTab: + if self.parent is None: + return self + return self.parent.get_top_level() + + def assign(self, var_name: str, var_type: Type) -> None: + self.locals[var_name] = var_type + + def assign_recursive(self, var_name: str, var_type: Type) -> None: + if self.locals.get(var_name) is not None: + self.locals[var_name] = var_type + else: + if self.parent is not None: + self.parent.assign_recursive(var_name, var_type) + else: + raise Exception() + + def lookup(self, var_name: str) -> Type: + res = self.locals.get(var_name) + if res is not None: + assert isinstance(res, Int | Bool | Unit | FunType) + return res + else: + raise Exception() + + def lookup_recursive(self, var_name: str) -> Type: + res = self.locals.get(var_name) + if res is not None: + assert isinstance(res, Int | Bool | Unit | FunType) + return res + else: + if self.parent is not None: + return self.parent.lookup_recursive(var_name) + else: + raise Exception() + +def typecheck_main(node: ast.Expression, sym_tab: TypeSymTab) -> Type: + + match node: + + case ast.Literal(): + if type(node.value) is bool: + return Bool_Instance + elif type(node.value) is int: + return Int_Instance + elif node.value is None: + return Unit_Instance + else: + raise Exception() + + case ast.Identifier(): + return sym_tab.lookup_recursive(node.name) + + case 
ast.BinaryOp(): + t1 = typecheck(node.left, sym_tab) + t2 = typecheck(node.right, sym_tab) + + binary_ops = sym_tab.get_top_level().locals['binary_ops'] + assert isinstance(binary_ops, dict) + func_type = binary_ops.get(node.op) + + if func_type is None: + + if node.op != '=': + raise Exception() + + # = operator + else: + if not isinstance(node.left, ast.Identifier): + raise Exception() + if t1 != t2: + raise Exception() + return t2 + + else: + assert isinstance(func_type, FunType) + assert isinstance(func_type.params, list) + + if isinstance(func_type.params[0], list): + if not any(t1 == possible_type for possible_type in func_type.params[0]): + raise Exception() + else: + if t1 != func_type.params[0]: + raise Exception() + + if isinstance(func_type.params[1], list): + if not any(t2 == possible_type for possible_type in func_type.params[1]): + raise Exception() + else: + if t2 != func_type.params[1]: + raise Exception() + + # == and != + if isinstance(func_type.params[0], list) and isinstance(func_type.params[1], list): + if t1 != t2: + raise Exception() + + return func_type.result + + case ast.UnaryOp(): + t1 = typecheck(node.right, sym_tab) + unary_ops = sym_tab.get_top_level().locals['unary_ops'] + assert isinstance(unary_ops, dict) + func_type = unary_ops.get(node.op) + assert isinstance(func_type, FunType) + assert isinstance(func_type.params, list) + if t1 != func_type.params[0]: + raise Exception() + return func_type.result + + case ast.If(): + e1 = node.cond_exp + e2 = node.then_exp + e3 = node.else_exp + + # if-then + if e3 is None: + return Unit_Instance + + # if-then-else + else: + t1 = typecheck(e1, sym_tab) + if t1 != Bool_Instance: + raise Exception() + t2 = typecheck(e2, sym_tab) + t3 = typecheck(e3, sym_tab) + if t2 != t3: + raise Exception() + return t2 + + case ast.While(): + if typecheck(node.while_exp, sym_tab) != Bool_Instance: + raise Exception() + return Unit_Instance + + case ast.FunctionCall(): + built_in_functions = 
sym_tab.get_top_level().locals['built_in_functions'] + assert isinstance(built_in_functions, dict) + func_type = built_in_functions.get(node.name) + + # Built-in functions + if func_type is not None: + assert isinstance(func_type, FunType) + assert isinstance(func_type.params, list) + if len(node.args) != len(func_type.params): + raise Exception() + if len(node.args) == 1: + if typecheck(node.args[0], sym_tab) != func_type.params[0]: + raise Exception() + return Unit_Instance + else: # len(node.args) == 0, i.e. node.name == 'read_int' + return Int_Instance + + else: + custom_func_type = sym_tab.lookup_recursive(node.name) + assert isinstance(custom_func_type, FunType) + assert isinstance(custom_func_type.params, list) + if len(node.args) != len(custom_func_type.params): + raise Exception() + for i in range(0, len(node.args)): + if typecheck(node.args[i], sym_tab) != custom_func_type.params[i]: + raise Exception() + return custom_func_type.result + + case ast.Block(): + if len(node.statements) == 0: + return Unit_Instance + + last_node = node.statements[len(node.statements) - 1] + + if last_node == ast.Literal(location=L, value=None): + new_sym_tab = TypeSymTab(locals={}, parent=sym_tab) + for idx, statement in enumerate(node.statements): + if idx == len(node.statements) - 1: + return Unit_Instance + else: + typecheck(node=statement, sym_tab=new_sym_tab) + + raise Exception() # The loop should have returned a result and this line should not have been reached. + + else: + new_sym_tab = TypeSymTab(locals={}, parent=sym_tab) + for idx, statement in enumerate(node.statements): + if idx == len(node.statements) - 1: + return typecheck(node=statement, sym_tab=new_sym_tab) + else: + typecheck(node=statement, sym_tab=new_sym_tab) + + raise Exception() # The loop should have returned a result and this line should not have been reached. 
+ + case ast.Var(): + if node.type_f is not None: + value_type = typecheck(node=node.value, sym_tab=sym_tab) + if value_type is Unit_Instance or node.type_f != value_type: + raise Exception() + sym_tab.assign(var_name=node.name, var_type=node.type_f) + return node.type_f + else: + value_type = typecheck(node=node.value, sym_tab=sym_tab) + sym_tab.assign(var_name=node.name, var_type=value_type) + return value_type + + case ast.Break(): + return Unit_Instance + + case ast.Continue(): + return Unit_Instance + + case ast.Var(): + return Unit_Instance + + case ast.Return(): + assert sym_tab.parent is not None + if typecheck(node.result, sym_tab) != sym_tab.parent.lookup_recursive('return'): + raise Exception() + return Unit_Instance + + case _: + raise Exception() + + + +def typecheck(node: ast.Expression, sym_tab: TypeSymTab | None = None) -> Type: + + if sym_tab is None: + sym_tab = TypeSymTab() + + resulting_type = typecheck_main(node, sym_tab) + + node.type = resulting_type + + return resulting_type + + +def typecheck_module(module: ast.Module) -> None: + st_block = TypeSymTab(locals={}, parent=None) + + for function_def in module.function_defs: + st_block.assign(function_def.name, FunType(params=[arg.type for arg in function_def.args], result=function_def.return_type)) + + for function_def in module.function_defs: + st = TypeSymTab(locals={}, parent=st_block) + for arg in function_def.args: + st.assign(arg.name, arg.type) + st.assign('return', function_def.return_type) + typecheck(function_def.block, st) + + typecheck(module.block, st_block) \ No newline at end of file diff --git a/src/compiler/types.py b/src/compiler/types.py new file mode 100644 index 0000000..25ecc0c --- /dev/null +++ b/src/compiler/types.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from dataclasses import dataclass + +class Int: + pass + +class Bool: + pass + +class Unit: + pass + +Int_Instance = Int() +Bool_Instance = Bool() +Unit_Instance = Unit() + +class Any(Int, Bool, Unit): + 
pass + +Any_Instance = Any() + +type Type = Int | Bool | Unit | FunType + +ParamsType = Type | list[Type | list['ParamsType']] + +@dataclass +class FunType: + params: ParamsType + result: Type diff --git a/test-gadget.py b/test-gadget.py new file mode 100755 index 0000000..2aca9e2 --- /dev/null +++ b/test-gadget.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Runs the correct test-gadget client program in .test-gadget/test-gadget-client-$PLATFORM + +import os +import platform +import sys +from pathlib import Path + + +def get_platform_binary() -> str: + system = platform.system().lower() + if system == "darwin": + return "test-gadget-client-macos" + elif system == "windows": + return "test-gadget-client-windows.exe" + elif system == "linux": + return "test-gadget-client-linux" + else: + print(f"Unsupported platform: {system}", file=sys.stderr) + sys.exit(1) + + +script_dir = Path(__file__).parent +dist_dir = script_dir / ".test-gadget" +binary = dist_dir / get_platform_binary() + +if not binary.exists(): + print(f"Program not found: {binary}", file=sys.stderr) + sys.exit(1) + +os.execv(str(binary), [str(binary)] + sys.argv[1:]) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assembly_generator_test.py b/tests/assembly_generator_test.py new file mode 100644 index 0000000..3225fe2 --- /dev/null +++ b/tests/assembly_generator_test.py @@ -0,0 +1,135 @@ +from compiler.assembly_generator import Locals, generate_assembly +from compiler.ir import IRVar, LoadIntConst, LoadBoolConst, Copy, CondJump, Label, Instruction, Call +from compiler.tokenizer import L +from typing import List + +def test_assembly_generator_locals_initialization() -> None: + variables = [IRVar('x'), IRVar('y'), IRVar('z')] + locals = Locals(variables) + + assert locals.get_ref(variables[0]) == '-8(%rbp)' + assert locals.get_ref(variables[1]) == '-16(%rbp)' + assert locals.get_ref(variables[2]) == '-24(%rbp)' + + assert locals.stack_used() 
== 24 # 3 variables * 8 bytes + +# def test_assembly_generator_load_int_const() -> None: +# ir_var = IRVar('x') +# instructions: List[Instruction] = [ +# LoadIntConst(L, 42, ir_var) +# ] + +# asm = generate_assembly(instructions) +# assert 'movq $42, -8(%rbp)' in asm + +# def test_assembly_generator_load_bool_const() -> None: +# instructions: List[Instruction] = [ +# LoadBoolConst(L, True, IRVar('a')), +# LoadBoolConst(L, False, IRVar('b')) +# ] + +# asm = generate_assembly(instructions) +# assert 'movq $1, -8(%rbp)' in asm # True +# assert 'movq $0, -16(%rbp)' in asm # False + +# def test_assembly_generator_copy() -> None: +# src = IRVar('src') +# dest = IRVar('dest') +# instructions: List[Instruction] = [ +# Copy(L, src, dest) +# ] + +# asm = generate_assembly(instructions) +# assert 'movq -8(%rbp), %rax' in asm +# assert 'movq %rax, -16(%rbp)' in asm + +# def test_assembly_generator_cond_jump() -> None: +# cond_var = IRVar('cond') +# then_label = Label(L, 'Lthen') +# else_label = Label(L, 'Lelse') +# instructions: List[Instruction] = [ +# CondJump(L, cond_var, then_label, else_label) +# ] + +# asm = generate_assembly(instructions) +# assert 'cmpq $0, -8(%rbp)' in asm +# assert 'jne .LLthen' in asm +# assert 'jmp .LLelse' in asm + +# def test_assembly_generator_function_prologue_epilogue() -> None: +# instructions: List[Instruction] = [ +# LoadIntConst(L, 0, IRVar('dummy')) +# ] + +# asm = generate_assembly(instructions) +# assert 'pushq %rbp' in asm +# assert 'movq %rsp, %rbp' in asm +# assert 'subq $8, %rsp' in asm +# assert 'movq %rbp, %rsp' in asm +# assert 'popq %rbp' in asm +# assert 'ret' in asm + +# def test_assembly_generator_intrinsic_plus() -> None: +# # IR: Call('+', [x, y], result) +# x = IRVar('x') +# y = IRVar('y') +# result = IRVar('result') +# instructions = [ +# LoadIntConst(L, 3, x), +# LoadIntConst(L, 5, y), +# Call(L, IRVar('+'), [x, y], result) +# ] +# asm = generate_assembly(instructions) +# assert 'addq' in asm +# assert 'movq' in asm +# 
assert 'callq' not in asm # Intrinsic should not use call + +# def test_assembly_generator_intrinsic_divide() -> None: +# # IR: Call('/', [a, b], result) +# a = IRVar('a') +# b = IRVar('b') +# instructions = [ +# LoadIntConst(L, 10, a), +# LoadIntConst(L, 2, b), +# Call(L, IRVar('/'), [a, b], IRVar('result')) +# ] +# asm = generate_assembly(instructions) +# assert 'idivq' in asm +# assert 'cqto' in asm + +# def test_assembly_generator_function_call_one_arg() -> None: +# # IR: Call(print_int, [x], _) +# x = IRVar('x') +# instructions = [ +# LoadIntConst(L, 42, x), +# Call(L, IRVar('print_int'), [x], IRVar('unused')) +# ] +# asm = generate_assembly(instructions) +# assert 'movq -8(%rbp), %rdi' in asm # Assuming x is at -8(%rbp) +# assert 'callq print_int' in asm + +# def test_assembly_generator_function_call_six_args() -> None: +# # IR: Call(func, [a,b,c,d,e,f], _) +# args = [IRVar(f'arg{i}') for i in range(6)] +# instructions = [ +# *[LoadIntConst(L, i, arg) for i, arg in enumerate(args)], +# Call(L, IRVar('func'), args, IRVar('result')) +# ] +# asm = generate_assembly(instructions) +# expected_regs = ['%rdi', '%rsi', '%rdx', '%rcx', '%r8', '%r9'] +# for i, reg in enumerate(expected_regs): +# assert f'movq -{8*(i+1)}(%rbp), {reg}' in asm +# assert 'callq func' in asm + +# def test_assembly_generator_comparison_intrinsic() -> None: +# # IR: Call('==', [x, y], result) +# x = IRVar('x') +# y = IRVar('y') +# instructions = [ +# LoadIntConst(L, 5, x), +# LoadIntConst(L, 5, y), +# Call(L, IRVar('=='), [x, y], IRVar('result')) +# ] +# asm = generate_assembly(instructions) +# assert 'cmpq' in asm +# assert 'sete' in asm diff --git a/tests/interpreter_test.py b/tests/interpreter_test.py new file mode 100644 index 0000000..ee5447f --- /dev/null +++ b/tests/interpreter_test.py @@ -0,0 +1,419 @@ +import pytest +from typing import Any +from compiler.parser import parse +from compiler.tokenizer import tokenize +from compiler.interpreter import interpret, SymTab, Unit + +def 
test_interpreter_addition() -> None: + exp = '2 + 3' + tokens = tokenize(exp) + ast = parse(tokens) + assert interpret(ast.block) == 5 + +def test_interpreter_boolean_coercion() -> None: + exp = '1 + (2 < 3)' + tokens = tokenize(exp) + ast = parse(tokens) + assert interpret(ast.block) == 2 + +def test_interpreter_nested_blocks() -> None: + code = ''' + { + var x = 5; + { + var y = 10; + x + y + } + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 15 + +def test_interpreter_variable_shadowing() -> None: + code = ''' + var x = 1; + { + var x = 2; + x + } + x + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 1 + +def test_interpreter_assignment_in_outer_scope() -> None: + code = ''' + var x = 5; + { + x = x + 1; + x + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 6 + +def test_interpreter_assignment_operator() -> None: + code = 'var x = 0; x = 5; x' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 5 + +def test_interpreter_operator_precedence() -> None: + code = '10 - 4 - 3' # (10-4)-3 = 3 + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 3 + +def test_interpreter_logical_operators() -> None: + code = 'true and false or true' # (true and false) or true = true + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_comparison_equality() -> None: + code = '5 > 3 == true' # (5>3) == true → true + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_unary_operators() -> None: + code = 'not (5 < 3)' # not(false) → true + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_block_scoping() -> None: + code = ''' + { + var a = 10; + { + var b = 20; + a + b + } + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert 
interpret(ast.block) == 30 + +def test_interpreter_variable_redeclaration() -> None: + code = ''' + { + var x = 1; + var x = 2; + x + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_while_loop() -> None: + code = ''' + var counter = 3; + while counter > 0 do { + counter = counter - 1 + }; + counter + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 0 + +def test_interpreter_empty_block() -> None: + code = '{}' + tokens = tokenize(code) + ast = parse(tokens) + assert isinstance(interpret(ast.block), Unit) + +def test_interpreter_block_last_semicolon() -> None: + code = ''' + { + var x = 5; + x; + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert isinstance(interpret(ast.block), Unit) + +def test_interpreter_if_without_else_returns_unit() -> None: + code = 'if false then 5' + tokens = tokenize(code) + ast = parse(tokens) + assert isinstance(interpret(ast.block), Unit) + +def test_interpreter_if_else_true_branch() -> None: + code = 'if true then 10 else 20' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 10 + +def test_interpreter_if_else_false_branch() -> None: + code = 'if false then 10 else 20' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 20 + +def test_interpreter_function_call_raises_error(capsys: pytest.CaptureFixture) -> None: + code = 'print_int(5)' + tokens = tokenize(code) + ast = parse(tokens) + assert isinstance(interpret(ast.block), Unit) + captured = capsys.readouterr() + assert captured.out == '5\n' + +def test_interpreter_division_modulo_operators() -> None: + code = '8 / 3 + 8 % 3' # 2 + 2 = 4 + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 4 + +def test_interpreter_assignment_to_non_identifier_raises() -> None: + code = '5 = 10' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + 
interpret(ast.block) + +def test_interpreter_while_loop_zero_iterations() -> None: + code = ''' + var x = 3; + while x > 5 do { + x = x - 1 + }; + x + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 3 + +def test_interpreter_nested_if_expression() -> None: + code = '(if true then 5 else 3) + 2' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 7 + +def test_interpreter_unary_not_on_boolean() -> None: + code = 'not true' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is False + +def test_interpreter_unary_negate_non_int_raises() -> None: + code = '-true' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_and_short_circuit() -> None: + code = ''' + var evaluated = false; + false and { evaluated = true; true }; + evaluated + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is False + +def test_interpreter_and_evaluates_both() -> None: + code = ''' + var evaluated = false; + true and { evaluated = true; true }; + evaluated + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_or_short_circuit() -> None: + code = ''' + var evaluated = false; + true or { evaluated = true; true }; + evaluated + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is False + +def test_interpreter_or_evaluates_both() -> None: + code = ''' + var evaluated = false; + false or { evaluated = true; true }; + evaluated + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_unary_negate_int() -> None: + code = '-5' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == -5 + +def test_interpreter_unary_negate_bool_raises() -> None: + code = '-true' + tokens = tokenize(code) + ast = parse(tokens) + with 
pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_unary_not_on_int_raises() -> None: + code = 'not 5' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_binary_operator_precedence() -> None: + code = '1 + 2 * 3 == 7' # 1 + (2 * 3) == 7 → true + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_binary_operator_associativity() -> None: + code = '10 - 4 - 3' # (10 - 4) - 3 = 3 + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 3 + +def test_interpreter_binary_operator_equality() -> None: + code = '5 == 5 and 3 != 4' # true and true → true + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_binary_operator_comparison() -> None: + code = '5 < 10 and 10 >= 10' # true and true → true + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) is True + +def test_interpreter_binary_operator_division_by_zero_raises() -> None: + code = '5 / 0' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_binary_operator_modulo_by_zero_raises() -> None: + code = '5 % 0' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_while_loop_multiple_iterations() -> None: + code = ''' + var x = 3; + while x > 0 do { + x = x - 1 + }; + x + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 0 + +def test_interpreter_block_with_multiple_statements() -> None: + code = ''' + { + var x = 5; + var y = 10; + x + y + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 15 + +def test_interpreter_block_with_trailing_semicolon() -> None: + code = ''' + { + var x = 5; + x; + } + ''' + tokens = tokenize(code) + ast = parse(tokens) 
+ assert isinstance(interpret(ast.block), Unit) + +def test_interpreter_block_with_nested_blocks() -> None: + code = ''' + { + var x = 5; + { + var y = 10; + x + y + } + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 15 + +def test_interpreter_variable_declaration_in_block() -> None: + code = ''' + { + var x = 5; + x + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 5 + +def test_interpreter_variable_redeclaration_in_block_raises() -> None: + code = ''' + { + var x = 5; + var x = 10; + x + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + with pytest.raises(Exception): + interpret(ast.block) + +def test_interpreter_assignment_in_block() -> None: + code = ''' + { + var x = 5; + x = x + 1; + x + } + ''' + tokens = tokenize(code) + ast = parse(tokens) + assert interpret(ast.block) == 6 + +def test_interpreter_function_call_print_int() -> None: + code = 'print_int(5)' + tokens = tokenize(code) + ast = parse(tokens) + assert isinstance(interpret(ast.block), Unit) + +def test_interpreter_function_call_print_bool() -> None: + code = 'print_bool(true)' + tokens = tokenize(code) + ast = parse(tokens) + assert isinstance(interpret(ast.block), Unit) + +def test_interpreter_function_call_read_int() -> None: + code = 'read_int()' + tokens = tokenize(code) + ast = parse(tokens) + # Mock input for read_int + import builtins + original_input: Any = builtins.input + func: Any = lambda: '42' + builtins.input = func + assert interpret(ast.block) == 42 + builtins.input = original_input diff --git a/tests/ir_generator_test.py b/tests/ir_generator_test.py new file mode 100644 index 0000000..15bf76b --- /dev/null +++ b/tests/ir_generator_test.py @@ -0,0 +1,12 @@ +from compiler.type_checker import typecheck_module +from compiler.tokenizer import L, tokenize +from compiler.parser import parse +from compiler.ir_generator import generate_ir_from_module, print_instructions, root_types + +def 
test_ir_generator_basic () -> None: + expr_str = '1 + 2 * 3' + tokens = tokenize(expr_str) + ast = parse(tokens) + typecheck_module(ast) + main_instructions = generate_ir_from_module(ast)['main'] + assert print_instructions(main_instructions) != '' diff --git a/tests/parser_test.py b/tests/parser_test.py new file mode 100644 index 0000000..f6ca50e --- /dev/null +++ b/tests/parser_test.py @@ -0,0 +1,2505 @@ +import pytest +from compiler.parser import parse +from compiler.tokenizer import tokenize, L, SourceLocation +import compiler.ast as ast +from compiler.types import Int_Instance, Bool_Instance, FunType + +def test_parser_bad_eoi() -> None: + tokens = tokenize('a + b * 5 - (3 + 7) / 10 4') + with pytest.raises(Exception): + parse(tokens) + + tokens = tokenize('a + b * 5 - (3 + 7) / 10 )') + with pytest.raises(Exception): + parse(tokens) + + tokens = tokenize('a + b * 5 - (3 + 7) / 10 asldkjaskjkajkdjsad') + with pytest.raises(Exception): + parse(tokens) + +# def test_parser_missing_exp() -> None: +# tokens = tokenize('a + b * 5 - (3 + 7) / 10 >=') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert str(exc_info.value).endswith(': expected "(", an integer literal or an identifier') + +# def test_parser_invalid_exp() -> None: +# tokens = tokenize('a + b * 5 - (3 + 7) / ;') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert str(exc_info.value).endswith(': expected "(", an integer literal or an identifier') + +# def test_parser_missing_closing_parenthesis() -> None: +# tokens = tokenize('a + b * 5 - (3 + 7) / 10 + (5') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert str(exc_info.value).endswith('"(" not followed by ")"') + +# def test_parser_empty_token_list() -> None: +# with pytest.raises(Exception) as exc_info: +# parse([]) +# assert str(exc_info.value) == 'len(tokens) == 0' + +# def test_parser_maths_exp() -> None: +# tokens = tokenize('a + b * 5 - (3 + 7) / 10') +# assert parse(tokens) == 
ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='+', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='*', +# right=ast.Literal(location=L, value=5) +# ) +# ), +# op='-', +# right=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Literal(location=L, value=3), +# op='+', +# right=ast.Literal(location=L, value=7) +# ), +# op='/', +# right=ast.Literal(location=L, value=10) +# ) +# ) + +# def test_parser_simple_function_call() -> None: +# tokens = tokenize('f(x)') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ast.Identifier(location=L, name='x')] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_nested_function_call() -> None: +# tokens = tokenize('f(g(x))') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.FunctionCall( +# location=L, +# name='g', +# args=[ast.Identifier(location=L, name='x')] +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_function_call_with_multiple_args() -> None: +# tokens = tokenize('f(x, y, z)') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.Identifier(location=L, name='x'), +# ast.Identifier(location=L, name='y'), +# ast.Identifier(location=L, name='z') +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_function_call_with_nested_expressions() -> None: +# tokens = tokenize('f(x + y, g(z))') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='+', +# right=ast.Identifier(location=L, name='y') +# ), +# ast.FunctionCall( +# location=L, +# name='g', +# args=[ast.Identifier(location=L, 
name='z')] +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_function_call_as_part_of_expression() -> None: +# tokens = tokenize('a * f(x) + 5') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='*', +# right=ast.FunctionCall( +# location=L, +# name='f', +# args=[ast.Identifier(location=L, name='x')] +# ) +# ), +# op='+', +# right=ast.Literal(location=L, value=5) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_nested_parentheses_with_function_call() -> None: +# tokens = tokenize('(f(x))') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ast.Identifier(location=L, name='x')] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_complex_nested_function_calls() -> None: +# tokens = tokenize('f(g(h(1), 2), x + y)') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.FunctionCall( +# location=L, +# name='g', +# args=[ +# ast.FunctionCall( +# location=L, +# name='h', +# args=[ast.Literal(location=L, value=1)] +# ), +# ast.Literal(location=L, value=2) +# ] +# ), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='+', +# right=ast.Identifier(location=L, name='y') +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_if_then_else() -> None: +# tokens = tokenize('if a then b + c else x * y') +# assert parse(tokens) == ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='a'), +# then_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ), +# else_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='*', +# 
right=ast.Identifier(location=L, name='y') +# ) +# ) + +# tokens = tokenize('if if a then d then b + c else x * y') +# assert parse(tokens) == ast.If( +# location=L, +# cond_exp=ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='a'), +# then_exp=ast.Identifier(location=L, name='d'), +# else_exp=None +# ), +# then_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ), +# else_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='*', +# right=ast.Identifier(location=L, name='y') +# ) +# ) + +# def test_parser_while() -> None: +# tokens = tokenize('while a do b + c') +# assert parse(tokens) == ast.While( +# location=L, +# while_exp=ast.Identifier(location=L, name='a'), +# do_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ) +# ) + +# def test_parser_unary_ops() -> None: +# tokens = tokenize('not a') +# assert parse(tokens) == ast.UnaryOp(location=L, op='not', right=ast.Identifier(location=L, name='a')) + +# tokens = tokenize('-a') +# assert parse(tokens) == ast.UnaryOp(location=L, op='-', right=ast.Identifier(location=L, name='a')) + +# tokens = tokenize('b - f(-a)') +# assert parse(tokens) == ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='-', +# right=ast.FunctionCall( +# location=L, +# name='f', +# args=[ast.UnaryOp(location=L, op='-', right=ast.Identifier(location=L, name='a'))] +# ) +# ) + +# tokens = tokenize('not not not a') +# assert parse(tokens) == ast.UnaryOp( +# location=L, +# op='not', +# right=ast.UnaryOp( +# location=L, +# op='not', +# right=ast.UnaryOp( +# location=L, +# op='not', +# right=ast.Identifier(location=L, name='a') +# ) +# ) +# ) + +# tokens = tokenize('---a') +# assert parse(tokens) == ast.UnaryOp( +# location=L, +# op='-', +# right=ast.UnaryOp( +# location=L, +# op='-', +# 
right=ast.UnaryOp( +# location=L, +# op='-', +# right=ast.Identifier(location=L, name='a') +# ) +# ) +# ) + +# tokens = tokenize('if if -a then not d then b + c else x * y') +# assert parse(tokens) == ast.If( +# location=L, +# cond_exp=ast.If( +# location=L, +# cond_exp=ast.UnaryOp(location=L, op='-', right=ast.Identifier(location=L, name='a')), +# then_exp=ast.UnaryOp(location=L, op='not', right=ast.Identifier(location=L, name='d')), +# else_exp=None +# ), +# then_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ), +# else_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='*', +# right=ast.Identifier(location=L, name='y') +# ) +# ) + +# def test_parser_simple_assignment() -> None: +# tokens = tokenize('a = b') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.Identifier(location=L, name='b') +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_chained_assignment() -> None: +# tokens = tokenize('a = b = c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='=', +# right=ast.Identifier(location=L, name='c') +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_assignment_with_arithmetic() -> None: +# tokens = tokenize('a = b + c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def 
test_parser_chained_assignment_with_arithmetic() -> None: +# tokens = tokenize('a = b = c + d') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='c'), +# op='+', +# right=ast.Identifier(location=L, name='d') +# ) +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_assignment_with_logic() -> None: +# tokens = tokenize('a = b and c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='and', +# right=ast.Identifier(location=L, name='c') +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_chained_assignment_with_logic() -> None: +# tokens = tokenize('a = b = c or d') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='c'), +# op='or', +# right=ast.Identifier(location=L, name='d') +# ) +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_assignment_with_nested_expressions() -> None: +# tokens = tokenize('a = (b + c) * d') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ), +# op='*', +# right=ast.Identifier(location=L, name='d') +# 
) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_assignment_with_function_call() -> None: +# tokens = tokenize('a = f(x)') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.FunctionCall( +# location=L, +# name='f', +# args=[ast.Identifier(location=L, name='x')] +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_chained_assignment_with_function_call() -> None: +# tokens = tokenize('a = b = f(x)') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='=', +# right=ast.FunctionCall( +# location=L, +# name='f', +# args=[ast.Identifier(location=L, name='x')] +# ) +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_chained_assignment_with_mixed_precedence() -> None: +# tokens = tokenize('a = b + c = d * e') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='d'), +# op='*', +# right=ast.Identifier(location=L, name='e') +# ) +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_blocks() -> None: +# assert parse(tokenize(""" +# { +# f(a); +# x = y; +# f(x) +# } +# """)) == ast.Block( +# location=L, +# statements=[ +# ast.FunctionCall(location=L, name='f', args=[ast.Identifier(location=L, name='a')]), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='=', +# 
right=ast.Identifier(location=L, name='y') +# ), +# ast.FunctionCall(location=L, name='f', args=[ast.Identifier(location=L, name='x')]), +# ] +# ) + +# assert parse(tokenize(""" +# { +# f(a); +# x = y; +# f(x); +# } +# """)) == ast.Block( +# location=L, +# statements=[ +# ast.FunctionCall(location=L, name='f', args=[ast.Identifier(location=L, name='a')]), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='=', +# right=ast.Identifier(location=L, name='y') +# ), +# ast.FunctionCall(location=L, name='f', args=[ast.Identifier(location=L, name='x')]), +# ast.Literal(location=L, value=None) +# ] +# ) + +# def test_parser_block_empty() -> None: +# tokens = tokenize('{}') +# parsed = parse(tokens) +# expected = ast.Block(location=L, statements=[]) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_single_statement_no_semicolon() -> None: +# tokens = tokenize('{ a }') +# parsed = parse(tokens) +# expected = ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_single_statement_trailing_semicolon() -> None: +# tokens = tokenize('{ a; }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_multiple_statements() -> None: +# tokens = tokenize('{ a; b; c }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.Identifier(location=L, name='b'), +# ast.Identifier(location=L, name='c') +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_multiple_statements_trailing_semicolon() -> None: +# tokens = tokenize('{ a; b; c; }') +# parsed = parse(tokens) +# expected 
= ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.Identifier(location=L, name='b'), +# ast.Identifier(location=L, name='c'), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_nested() -> None: +# tokens = tokenize('{ { a; } }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.Literal(location=L, value=None) +# ] +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_with_if() -> None: +# tokens = tokenize('{ if a then { b; } else { c } }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='a'), +# then_exp=ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='b'), +# ast.Literal(location=L, value=None) +# ] +# ), +# else_exp=ast.Block( +# location=L, +# statements=[ast.Identifier(location=L, name='c')] +# ) +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_with_assignment() -> None: +# tokens = tokenize('{ a = b; c = d; }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.Identifier(location=L, name='b') +# ), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='c'), +# op='=', +# right=ast.Identifier(location=L, name='d') +# ), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_missing_closing_brace() -> None: +# tokens = tokenize('{ a') +# with pytest.raises(Exception): +# parse(tokens) + +# def 
test_parser_block_invalid_empty_statement() -> None: +# tokens = tokenize('{ ; }') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'expected "(", an integer literal or an identifier' in str(exc_info.value) + +# def test_parser_block_with_function_call() -> None: +# tokens = tokenize('{ f(a); }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.FunctionCall(location=L, name='f', args=[ast.Identifier(location=L, name='a')]), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_with_mixed_expressions() -> None: +# tokens = tokenize('{ a + b; if c then d; }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='+', +# right=ast.Identifier(location=L, name='b') +# ), +# ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='c'), +# then_exp=ast.Identifier(location=L, name='d'), +# else_exp=None +# ), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_top_level() -> None: +# tokens = tokenize('var a = 5') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='a', +# value=ast.Literal(location=L, value=5), +# type_f=None +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_in_block() -> None: +# tokens = tokenize('{ var b = 10 }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='b', value=ast.Literal(location=L, value=10), type_f=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_multiple_var_declarations_in_block() -> None: +# tokens = tokenize('{ var x = 1; var y = 2 }') +# parsed = parse(tokens) +# expected 
= ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='x', value=ast.Literal(location=L, value=1), type_f=None), +# ast.Var(location=L, name='y', value=ast.Literal(location=L, value=2), type_f=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_with_expression() -> None: +# tokens = tokenize('var sum = a + b') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='sum', +# value=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='+', +# right=ast.Identifier(location=L, name='b') +# ), +# type_f=None +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_trailing_semicolon() -> None: +# tokens = tokenize('{ var x = 5; }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='x', value=ast.Literal(location=L, value=5), type_f=None), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_with_nested_blocks() -> None: +# tokens = tokenize('{ { var x = 1 } }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='x', value=ast.Literal(location=L, value=1), type_f=None) +# ] +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_with_assignment() -> None: +# tokens = tokenize('var x = y = 5') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='x', +# value=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='y'), +# op='=', +# right=ast.Literal(location=L, value=5) +# ), +# type_f=None +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_missing_initializer() -> None: +# tokens = 
tokenize('var x') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'expected "="' in str(exc_info.value) + +# def test_parser_var_declaration_in_invalid_location() -> None: +# tokens = tokenize('f(var x = 5)') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert '"var" is only allowed directly inside blocks {} and in top-level expressions' in str(exc_info.value) + +# def test_parser_var_declaration_invalid_identifier() -> None: +# tokens = tokenize('var 123 = 5') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'expected an identifier' in str(exc_info.value) + +# def test_parser_var_declaration_with_boolean_literal() -> None: +# tokens = tokenize('var flag = true') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='flag', +# value=ast.Literal(location=L, value=True), +# type_f=None +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_in_if_block() -> None: +# tokens = tokenize('if x then { var y = 10 } else { var z = 20 }') +# parsed = parse(tokens) +# expected = ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='x'), +# then_exp=ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='y', value=ast.Literal(location=L, value=10), type_f=None) +# ] +# ), +# else_exp=ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='z', value=ast.Literal(location=L, value=20), type_f=None) +# ] +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_with_function_call_value() -> None: +# tokens = tokenize('var result = calculate(5)') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='result', +# value=ast.FunctionCall( +# location=L, +# name='calculate', +# args=[ast.Literal(location=L, value=5)] +# ), +# type_f=None +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def 
test_parser_var_declaration_with_logical_expression() -> None: +# tokens = tokenize('var valid = a > 5 and b != 10') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='valid', +# value=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='>', +# right=ast.Literal(location=L, value=5) +# ), +# op='and', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='!=', +# right=ast.Literal(location=L, value=10) +# ) +# ), +# type_f=None +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_with_mixed_block() -> None: +# tokens = tokenize(''' +# { +# var a = 1; +# b = a * 2; +# var c = b + 3 +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='a', value=ast.Literal(location=L, value=1), type_f=None), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='*', +# right=ast.Literal(location=L, value=2) +# ) +# ), +# ast.Var( +# location=L, +# name='c', +# value=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Literal(location=L, value=3) +# ), +# type_f=None +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_in_while_block() -> None: +# tokens = tokenize('while true do { var counter = 0 }') +# parsed = parse(tokens) +# expected = ast.While( +# location=L, +# while_exp=ast.Literal(location=L, value=True), +# do_exp=ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='counter', value=ast.Literal(location=L, value=0), type_f=None) +# ] +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# # def test_parser_var_declaration_with_type_annotation() -> 
None: +# # tokens = tokenize('var x: Int = 5') +# # parsed = parse(tokens) +# # expected = ast.Var(location=L, name='x', value=ast.Literal(location=L, value=5), type_f=None) +# # assert parsed == expected, f'Expected type annotations to be ignored, got {parsed}' + +# def test_parser_var_declaration_followed_by_usage() -> None: +# tokens = tokenize(''' +# { +# var x = 10; +# x + 5 +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='x', value=ast.Literal(location=L, value=10), type_f=None), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='+', +# right=ast.Literal(location=L, value=5) +# ) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_var_declaration_in_nested_expression_error() -> None: +# tokens = tokenize('a + var x = 5') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'allowed directly' in str(exc_info.value), 'Should prevent var in expressions' + +# def test_parser_var_declaration_missing_value_error() -> None: +# tokens = tokenize('var x =') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'expected "(", an integer literal or an identifier' in str(exc_info.value) + +# def test_parser_block_auto_semicolon_after_brace_valid() -> None: +# tokens = tokenize('{ { a } { b } }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]), +# ast.Block(location=L, statements=[ast.Identifier(location=L, name='b')]) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_consecutive_expressions_in_block_invalid() -> None: +# tokens = tokenize('{ a b }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_block_after_if_without_semicolon_valid() -> None: +# tokens = tokenize('{ if true then { a } b }') +# 
parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Literal(location=L, value=True), +# then_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]), +# else_exp=None +# ), +# ast.Identifier(location=L, name='b') +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_block_after_if_with_semicolon_valid() -> None: +# tokens = tokenize('{ if true then { a }; b }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Literal(location=L, value=True), +# then_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]), +# else_exp=None +# ), +# ast.Identifier(location=L, name='b') +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_consecutive_after_if_block_invalid() -> None: +# tokens = tokenize('{ if true then { a } b c }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_mixed_implicit_explicit_semicolons_valid() -> None: +# tokens = tokenize('{ if true then { a } b; c }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Literal(location=L, value=True), +# then_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]), +# else_exp=None +# ), +# ast.Identifier(location=L, name='b'), +# ast.Identifier(location=L, name='c') +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_if_else_with_implicit_semicolon_valid() -> None: +# tokens = tokenize('{ if true then { a } else { b } 3 }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Literal(location=L, value=True), +# then_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]), +# 
else_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='b')]) +# ), +# ast.Literal(location=L, value=3) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_nested_blocks_in_assignment_valid() -> None: +# tokens = tokenize('x = { { f(a) } { b } }') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='=', +# right=ast.Block( +# location=L, +# statements=[ +# ast.Block( +# location=L, +# statements=[ast.FunctionCall(location=L, name='f', args=[ast.Identifier(location=L, name='a')])] +# ), +# ast.Block(location=L, statements=[ast.Identifier(location=L, name='b')]) +# ] +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_consecutive_expressions_non_brace_invalid() -> None: +# tokens = tokenize('{ a { b } }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_var_declarations_missing_semicolon_invalid() -> None: +# tokens = tokenize('{ var x = 5 var y = 6 }') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'unexpected token "var"' in str(exc_info.value), "Should require semicolon between var declarations" + +# def test_parser_function_call_followed_by_block_invalid() -> None: +# tokens = tokenize('{ f() { g() } }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_while_block_followed_by_expression_valid() -> None: +# tokens = tokenize('{ while cond do { a } b }') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.While( +# location=L, +# while_exp=ast.Identifier(location=L, name='cond'), +# do_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='a')]) +# ), +# ast.Identifier(location=L, name='b') +# ] +# ) +# assert parsed == expected, "Should allow expression after while block without semicolon" + +# def test_parser_if_else_and_while_without_semicolon_valid() -> None: 
+# tokens = tokenize(''' +# { +# if a then { b } else { c } +# while d do { e } +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='a'), +# then_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='b')]), +# else_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='c')]) +# ), +# ast.While( +# location=L, +# while_exp=ast.Identifier(location=L, name='d'), +# do_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='e')]) +# ) +# ] +# ) +# assert parsed == expected, "Should allow consecutive blocks ending with } without semicolons" + +# def test_parser_var_declaration_followed_by_block_without_semicolon_invalid() -> None: +# tokens = tokenize('{ var x = 5 { y } }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_nested_blocks_with_expressions_valid() -> None: +# tokens = tokenize(''' +# { +# { a; { b } } +# { c }; { d } +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.Block(location=L, statements=[ast.Identifier(location=L, name='b')]) +# ] +# ), +# ast.Block(location=L, statements=[ast.Identifier(location=L, name='c')]), +# ast.Block(location=L, statements=[ast.Identifier(location=L, name='d')]) +# ] +# ) +# assert parsed == expected, "Should handle nested blocks and semicolons correctly" + +# def test_parser_mixed_constructs_with_optional_semicolons_valid() -> None: +# tokens = tokenize(''' +# { +# if x then { y } else { z } +# while true do { var i = 0 } +# { f(a, b, 5, false) } +# g() +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='x'), +# then_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='y')]), +# 
else_exp=ast.Block(location=L, statements=[ast.Identifier(location=L, name='z')]) +# ), +# ast.While( +# location=L, +# while_exp=ast.Literal(location=L, value=True), +# do_exp=ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='i', value=ast.Literal(location=L, value=0), type_f=None) +# ] +# ) +# ), +# ast.Block( +# location=L, +# statements=[ +# ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.Identifier(location=L, name='a'), +# ast.Identifier(location=L, name='b'), +# ast.Literal(location=L, value=5), +# ast.Literal(location=L, value=False) +# ] +# ) +# ] +# ), +# ast.FunctionCall(location=L, name='g', args=[]) +# ] +# ) +# assert parsed == expected, "Should parse mixed constructs with optional semicolons" + +# def test_parser_deeply_nested_if_else() -> None: +# tokens = tokenize(''' +# if a then +# if b then +# if c then 1 +# else 2 +# else 3 +# else 4 +# ''') +# parsed = parse(tokens) +# expected = ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='a'), +# then_exp=ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='b'), +# then_exp=ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='c'), +# then_exp=ast.Literal(location=L, value=1), +# else_exp=ast.Literal(location=L, value=2) +# ), +# else_exp=ast.Literal(location=L, value=3) +# ), +# else_exp=ast.Literal(location=L, value=4) +# ) +# assert parsed == expected, "Should parse deeply nested if-else structures" + +# def test_parser_mixed_block_with_all_features() -> None: +# tokens = tokenize(''' +# { +# var x = f(a, b + c); +# while x > 0 do { +# if y then { +# z = z * 2; +# print(z) +# } else { +# z = z / 2 +# }; +# x = x - 1 +# }; +# x +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Var( +# location=L, +# name='x', +# value=ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.Identifier(location=L, name='a'), +# ast.BinaryOp( +# location=L, +# 
left=ast.Identifier(location=L, name='b'), +# op='+', +# right=ast.Identifier(location=L, name='c') +# ) +# ] +# ), +# type_f=None +# ), +# ast.While( +# location=L, +# while_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='>', +# right=ast.Literal(location=L, value=0) +# ), +# do_exp=ast.Block( +# location=L, +# statements=[ +# ast.If( +# location=L, +# cond_exp=ast.Identifier(location=L, name='y'), +# then_exp=ast.Block( +# location=L, +# statements=[ +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='z'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='z'), +# op='*', +# right=ast.Literal(location=L, value=2) +# ) +# ), +# ast.FunctionCall( +# location=L, +# name='print', +# args=[ast.Identifier(location=L, name='z')] +# ) +# ] +# ), +# else_exp=ast.Block( +# location=L, +# statements=[ +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='z'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='z'), +# op='/', +# right=ast.Literal(location=L, value=2) +# ) +# ) +# ] +# ) +# ), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='-', +# right=ast.Literal(location=L, value=1) +# ) +# ) +# ] +# ) +# ), +# ast.Identifier(location=L, name='x') +# ] +# ) +# assert parsed == expected, "Should handle complex blocks with mixed features" + +# def test_parser_function_calls_as_arguments() -> None: +# tokens = tokenize('f(g(h(1)), x + y(z))') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.FunctionCall( +# location=L, +# name='g', +# args=[ +# ast.FunctionCall( +# location=L, +# name='h', +# args=[ast.Literal(location=L, value=1)] +# ) +# ] +# ), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='+', +# 
right=ast.FunctionCall( +# location=L, +# name='y', +# args=[ast.Identifier(location=L, name='z')] +# ) +# ) +# ] +# ) +# assert parsed == expected, "Should handle nested function calls as arguments" + +# def test_parser_assignment_in_condition() -> None: +# tokens = tokenize('if (x = read()) > 0 then print(x) else 0') +# parsed = parse(tokens) +# expected = ast.If( +# location=L, +# cond_exp=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='=', +# right=ast.FunctionCall(location=L, name='read', args=[]) +# ), +# op='>', +# right=ast.Literal(location=L, value=0) +# ), +# then_exp=ast.FunctionCall(location=L, name='print', args=[ast.Identifier(location=L, name='x')]), +# else_exp=ast.Literal(location=L, value=0) +# ) +# assert parsed == expected, "Should allow assignments in conditional expressions" + +# def test_parser_chained_unary_operations() -> None: +# tokens = tokenize('not -(-x + y)') +# parsed = parse(tokens) +# expected = ast.UnaryOp( +# location=L, +# op='not', +# right=ast.UnaryOp( +# location=L, +# op='-', +# right=ast.BinaryOp( +# location=L, +# left=ast.UnaryOp( +# location=L, +# op='-', +# right=ast.Identifier(location=L, name='x') +# ), +# op='+', +# right=ast.Identifier(location=L, name='y') +# ) +# ) +# ) +# assert parsed == expected, "Should handle complex unary operator combinations" + +# def test_parser_mixed_type_block_result() -> None: +# tokens = tokenize(''' +# { +# { a; { b; c } }; +# { d() }; +# e +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='b'), +# ast.Identifier(location=L, name='c') +# ] +# ) +# ] +# ), +# ast.Block( +# location=L, +# statements=[ast.FunctionCall(location=L, name='d', args=[])] +# ), +# ast.Identifier(location=L, name='e') +# ] +# ) +# 
assert parsed == expected, "Should handle nested blocks with mixed statement types" + +# def test_parser_complex_operator_precedence() -> None: +# tokens = tokenize('a + b * c == d and not e or f % g <= 5') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='+', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='*', +# right=ast.Identifier(location=L, name='c') +# ) +# ), +# op='==', +# right=ast.Identifier(location=L, name='d') +# ), +# op='and', +# right=ast.UnaryOp( +# location=L, +# op='not', +# right=ast.Identifier(location=L, name='e') +# ) +# ), +# op='or', +# right=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='f'), +# op='%', +# right=ast.Identifier(location=L, name='g') +# ), +# op='<=', +# right=ast.Literal(location=L, value=5) +# ) +# ) +# assert parsed == expected, "Should respect operator precedence hierarchy" + +# def test_parser_nested_arithmetic_and_comparison() -> None: +# tokens = tokenize('x * y + z < a - b % c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='*', +# right=ast.Identifier(location=L, name='y') +# ), +# op='+', +# right=ast.Identifier(location=L, name='z') +# ), +# op='<', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='-', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='%', +# right=ast.Identifier(location=L, name='c') +# ) +# ) +# ) +# assert parsed == expected, "Should prioritize % then */- then + then <" + +# def test_parser_logical_ops_with_comparisons() -> None: +# tokens = tokenize('a == b and c != d or e < f') +# 
parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='==', +# right=ast.Identifier(location=L, name='b') +# ), +# op='and', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='c'), +# op='!=', +# right=ast.Identifier(location=L, name='d') +# ) +# ), +# op='or', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='e'), +# op='<', +# right=ast.Identifier(location=L, name='f') +# ) +# ) +# assert parsed == expected, "Should group comparisons first, then and/or" + +# def test_parser_right_associative_assignment() -> None: +# tokens = tokenize('a = b = c + d * e') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='c'), +# op='+', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='d'), +# op='*', +# right=ast.Identifier(location=L, name='e') +# ) +# ) +# ) +# ) +# assert parsed == expected, "Assignment should be right-associative" + +# def test_parser_unary_ops_with_arithmetic() -> None: +# tokens = tokenize('not -x * y') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.UnaryOp( +# location=L, +# op='not', +# right=ast.UnaryOp( +# location=L, +# op='-', +# right=ast.Identifier(location=L, name='x') +# ) +# ), +# op='*', +# right=ast.Identifier(location=L, name='y') +# ) +# assert parsed == expected + +# def test_parser_function_calls_in_expressions() -> None: +# tokens = tokenize('a + f(b * c) >= g(d, e) and not h()') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# 
left=ast.Identifier(location=L, name='a'), +# op='+', +# right=ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='*', +# right=ast.Identifier(location=L, name='c') +# ) +# ] +# ) +# ), +# op='>=', +# right=ast.FunctionCall( +# location=L, +# name='g', +# args=[ +# ast.Identifier(location=L, name='d'), +# ast.Identifier(location=L, name='e') +# ] +# ) +# ), +# op='and', +# right=ast.UnaryOp( +# location=L, +# op='not', +# right=ast.FunctionCall( +# location=L, +# name='h', +# args=[] +# ) +# ) +# ) +# assert parsed == expected, "Should handle function calls within expressions" + +# def test_parser_locations_literal() -> None: +# tokens = tokenize('123') +# parsed = parse(tokens) +# expected_loc = SourceLocation(file='', line=0, column=0) +# assert parsed == ast.Literal(location=expected_loc, value=123) + +# def test_parser_locations_identifier() -> None: +# tokens = tokenize('x') +# parsed = parse(tokens) +# expected_loc = SourceLocation(file='', line=0, column=0) +# assert parsed == ast.Identifier(location=expected_loc, name='x') + +# def test_parser_locations_binary_op() -> None: +# tokens = tokenize('1 + 2') +# parsed = parse(tokens) +# op_loc = SourceLocation(file='', line=0, column=2) +# left_loc = SourceLocation(file='', line=0, column=0) +# right_loc = SourceLocation(file='', line=0, column=4) +# assert parsed == ast.BinaryOp( +# location=op_loc, +# left=ast.Literal(location=left_loc, value=1), +# op='+', +# right=ast.Literal(location=right_loc, value=2) +# ) + +# def test_parser_locations_function_call() -> None: +# tokens = tokenize('f(x)') +# parsed = parse(tokens) +# f_loc = SourceLocation(file='', line=0, column=0) +# x_loc = SourceLocation(file='', line=0, column=2) +# assert parsed == ast.FunctionCall( +# location=f_loc, +# name='f', +# args=[ast.Identifier(location=x_loc, name='x')] +# ) + +# def test_parser_locations_unary_op() -> None: +# tokens = 
tokenize('-x') +# parsed = parse(tokens) +# op_loc = SourceLocation(file='', line=0, column=0) +# x_loc = SourceLocation(file='', line=0, column=1) +# assert parsed == ast.UnaryOp( +# location=op_loc, +# op='-', +# right=ast.Identifier(location=x_loc, name='x') +# ) + +# def test_parser_locations_if_statement() -> None: +# tokens = tokenize('if a then b else c') +# parsed = parse(tokens) +# if_loc = SourceLocation(file='', line=0, column=0) +# a_loc = SourceLocation(file='', line=0, column=3) +# b_loc = SourceLocation(file='', line=0, column=10) +# c_loc = SourceLocation(file='', line=0, column=17) +# assert parsed == ast.If( +# location=if_loc, +# cond_exp=ast.Identifier(location=a_loc, name='a'), +# then_exp=ast.Identifier(location=b_loc, name='b'), +# else_exp=ast.Identifier(location=c_loc, name='c') +# ) + +# def test_parser_locations_nested_function_calls() -> None: +# tokens = tokenize('f(g(h(1)), 2 + x)') +# parsed = parse(tokens) +# assert parsed == ast.FunctionCall( +# location=SourceLocation(file='', line=0, column=0), +# name='f', +# args=[ +# ast.FunctionCall( +# location=SourceLocation(file='', line=0, column=2), +# name='g', +# args=[ +# ast.FunctionCall( +# location=SourceLocation(file='', line=0, column=4), +# name='h', +# args=[ast.Literal( +# location=SourceLocation(file='', line=0, column=6), +# value=1 +# )] +# ) +# ] +# ), +# ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=12), +# left=ast.Literal( +# location=SourceLocation(file='', line=0, column=10), +# value=2 +# ), +# op='+', +# right=ast.Identifier( +# location=SourceLocation(file='', line=0, column=14), +# name='x' +# ) +# ) +# ] +# ) + +# def test_parser_locations_assignment_chain() -> None: +# tokens = tokenize('a = b = c + d') +# parsed = parse(tokens) +# assert parsed == ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=2), +# left=ast.Identifier( +# location=SourceLocation(file='', line=0, column=0), +# name='a' +# ), +# op='=', +# right=ast.BinaryOp( 
+# location=SourceLocation(file='', line=0, column=6), +# left=ast.Identifier( +# location=SourceLocation(file='', line=0, column=4), +# name='b' +# ), +# op='=', +# right=ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=10), +# left=ast.Identifier( +# location=SourceLocation(file='', line=0, column=8), +# name='c' +# ), +# op='+', +# right=ast.Identifier( +# location=SourceLocation(file='', line=0, column=12), +# name='d' +# ) +# ) +# ) +# ) + +# def test_parser_locations_multiline_nested_if() -> None: +# tokens = tokenize('''if +# x < 10 then +# if y then 1 +# else 2 +# else 3''') +# parsed = parse(tokens) +# assert parsed == ast.If( +# location=SourceLocation(file='', line=0, column=0), +# cond_exp=ast.BinaryOp( +# location=SourceLocation(file='', line=1, column=4), +# left=ast.Identifier( +# location=SourceLocation(file='', line=1, column=2), +# name='x' +# ), +# op='<', +# right=ast.Literal( +# location=SourceLocation(file='', line=1, column=6), +# value=10 +# ) +# ), +# then_exp=ast.If( +# location=SourceLocation(file='', line=2, column=4), +# cond_exp=ast.Identifier( +# location=SourceLocation(file='', line=2, column=7), +# name='y' +# ), +# then_exp=ast.Literal( +# location=SourceLocation(file='', line=2, column=12), +# value=1 +# ), +# else_exp=ast.Literal( +# location=SourceLocation(file='', line=3, column=9), +# value=2 +# ) +# ), +# else_exp=ast.Literal( +# location=SourceLocation(file='', line=4, column=7), +# value=3 +# ) +# ) + +# def test_parser_locations_complex_assignment_in_condition() -> None: +# tokens = tokenize('if (x = read()) > 0 then print(x + 1)') +# parsed = parse(tokens) +# assert parsed == ast.If( +# location=SourceLocation(file='', line=0, column=0), +# cond_exp=ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=13), +# left=ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=6), +# left=ast.Identifier( +# location=SourceLocation(file='', line=0, column=4), +# name='x' +# ), +# op='=', +# 
right=ast.FunctionCall( +# location=SourceLocation(file='', line=0, column=9), +# name='read', +# args=[] +# ) +# ), +# op='>', +# right=ast.Literal( +# location=SourceLocation(file='', line=0, column=15), +# value=0 +# ) +# ), +# then_exp=ast.FunctionCall( +# location=SourceLocation(file='', line=0, column=22), +# name='print', +# args=[ +# ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=30), +# left=ast.Identifier( +# location=SourceLocation(file='', line=0, column=28), +# name='x' +# ), +# op='+', +# right=ast.Literal( +# location=SourceLocation(file='', line=0, column=32), +# value=1 +# ) +# ) +# ] +# ), +# else_exp=None +# ) + +# def test_parser_locations_complex_while_block() -> None: +# tokens = tokenize('''while x > 0 do { +# x = x - 1; +# print(x) +# }''') +# parsed = parse(tokens) +# assert parsed == ast.While( +# location=SourceLocation(file='', line=0, column=0), +# while_exp=ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=8), +# left=ast.Identifier( +# location=SourceLocation(file='', line=0, column=6), +# name='x' +# ), +# op='>', +# right=ast.Literal( +# location=SourceLocation(file='', line=0, column=10), +# value=0 +# ) +# ), +# do_exp=ast.Block( +# location=SourceLocation(file='', line=0, column=14), +# statements=[ +# ast.BinaryOp( +# location=SourceLocation(file='', line=1, column=8), +# left=ast.Identifier( +# location=SourceLocation(file='', line=1, column=4), +# name='x' +# ), +# op='=', +# right=ast.BinaryOp( +# location=SourceLocation(file='', line=1, column=12), +# left=ast.Identifier( +# location=SourceLocation(file='', line=1, column=10), +# name='x' +# ), +# op='-', +# right=ast.Literal( +# location=SourceLocation(file='', line=1, column=14), +# value=1 +# ) +# ) +# ), +# ast.FunctionCall( +# location=SourceLocation(file='', line=2, column=4), +# name='print', +# args=[ast.Identifier( +# location=SourceLocation(file='', line=2, column=10), +# name='x' +# )] +# ) +# ] +# ) +# ) + +# def 
test_parser_top_level_single_expression() -> None: +# tokens = tokenize('a') +# parsed = parse(tokens) +# expected = ast.Identifier( +# location=SourceLocation(file='', line=0, column=0), +# name='a' +# ) + +# assert parsed == expected + +# def test_parser_top_level_two_expressions() -> None: +# tokens = tokenize('a; b') +# parsed = parse(tokens) +# expected = ast.Block( +# location=SourceLocation(file='', line=0, column=0), +# statements=[ +# ast.Identifier( +# location=SourceLocation(file='', line=0, column=0), +# name='a' +# ), +# ast.Identifier( +# location=SourceLocation(file='', line=0, column=3), +# name='b' +# ) +# ] +# ) +# assert parsed == expected + +# def test_parser_top_level_trailing_semicolon() -> None: +# tokens = tokenize('a; b;') +# parsed = parse(tokens) +# expected = ast.Block( +# location=SourceLocation(file='', line=0, column=0), +# statements=[ +# ast.Identifier( +# location=SourceLocation(file='', line=0, column=0), +# name='a' +# ), +# ast.Identifier( +# location=SourceLocation(file='', line=0, column=3), +# name='b' +# ), +# ast.Literal( +# location=SourceLocation(file='', line=0, column=5), +# value=None +# ) +# ] +# ) +# assert parsed == expected + +# def test_parser_top_level_mixed_types() -> None: +# tokens = tokenize(''' +# var x = 5; # this is a comment that should be ignored +# x = x + 1; +# print(x) +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=SourceLocation(file='', line=1, column=4), +# statements=[ +# ast.Var( +# location=SourceLocation(file='', line=1, column=4), +# name='x', +# value=ast.Literal( +# location=SourceLocation(file='', line=1, column=10), +# value=5 +# ), +# type_f=None +# ), +# ast.BinaryOp( +# location=SourceLocation(file='', line=2, column=7), +# left=ast.Identifier( +# location=SourceLocation(file='', line=2, column=4), +# name='x' +# ), +# op='=', +# right=ast.BinaryOp( +# location=SourceLocation(file='', line=2, column=11), +# left=ast.Identifier( +# location=SourceLocation(file='', 
line=2, column=9), +# name='x' +# ), +# op='+', +# right=ast.Literal( +# location=SourceLocation(file='', line=2, column=13), +# value=1 +# ) +# ) +# ), +# ast.FunctionCall( +# location=SourceLocation(file='', line=3, column=4), +# name='print', +# args=[ast.Identifier( +# location=SourceLocation(file='', line=3, column=10), +# name='x' +# )] +# ) +# ] +# ) +# assert parsed == expected + +# def test_parser_top_level_empty_semicolons() -> None: +# tokens = tokenize(';;') +# with pytest.raises(Exception) as exc_info: +# parse(tokens) +# assert 'expected "(", an integer literal or an identifier' in str(exc_info.value) + +# def test_parser_top_level_complex_block() -> None: +# tokens = tokenize(''' +# if true then x = 5; // blah blah blah +# while x > 0 do { +# print(x); +# x = x - 1 +# } +# ''') +# parsed = parse(tokens) +# expected = ast.Block( +# location=SourceLocation(file='', line=1, column=4), +# statements=[ +# ast.If( +# location=SourceLocation(file='', line=1, column=4), +# cond_exp=ast.Literal( +# location=SourceLocation(file='', line=1, column=7), +# value=True +# ), +# then_exp=ast.BinaryOp( +# location=SourceLocation(file='', line=1, column=16), +# left=ast.Identifier( +# location=SourceLocation(file='', line=1, column=14), +# name='x' +# ), +# op='=', +# right=ast.Literal( +# location=SourceLocation(file='', line=1, column=18), +# value=5 +# ) +# ), +# else_exp=None +# ), +# ast.While( +# location=SourceLocation(file='', line=2, column=4), +# while_exp=ast.BinaryOp( +# location=SourceLocation(file='', line=2, column=11), +# left=ast.Identifier( +# location=SourceLocation(file='', line=2, column=9), +# name='x' +# ), +# op='>', +# right=ast.Literal( +# location=SourceLocation(file='', line=2, column=13), +# value=0 +# ) +# ), +# do_exp=ast.Block( +# location=SourceLocation(file='', line=2, column=17), +# statements=[ +# ast.FunctionCall( +# location=SourceLocation(file='', line=3, column=6), +# name='print', +# args=[ast.Identifier( +# 
location=SourceLocation(file='', line=3, column=12), +# name='x' +# )] +# ), +# ast.BinaryOp( +# location=SourceLocation(file='', line=4, column=9), +# left=ast.Identifier( +# location=SourceLocation(file='', line=4, column=6), +# name='x' +# ), +# op='=', +# right=ast.BinaryOp( +# location=SourceLocation(file='', line=4, column=13), +# left=ast.Identifier( +# location=SourceLocation(file='', line=4, column=11), +# name='x' +# ), +# op='-', +# right=ast.Literal( +# location=SourceLocation(file='', line=4, column=15), +# value=1 +# ) +# ) +# ) +# ] +# ) +# ) +# ] +# ) +# assert parsed == expected + +# def test_parser_complex_block_with_var_and_nested_blocks() -> None: +# tokens = tokenize('a; b % g; { var c = f((5 + d) * f - e)}; x = 2; {{}};') +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Identifier(location=L, name='a'), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='b'), +# op='%', +# right=ast.Identifier(location=L, name='g') +# ), +# ast.Block( +# location=L, +# statements=[ +# ast.Var( +# location=L, +# name='c', +# value=ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Literal(location=L, value=5), +# op='+', +# right=ast.Identifier(location=L, name='d') +# ), +# op = '*', +# right=ast.Identifier(location=L, name='f')), +# op='-', +# right=ast.Identifier(location=L, name='e') +# ) +# ] +# ), +# type_f=None +# ) +# ] +# ), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(location=L, name='x'), +# op='=', +# right=ast.Literal(location=L, value=2) +# ), +# ast.Block( +# location=L, +# statements=[ +# ast.Block( +# location=L, +# statements=[] +# ) +# ] +# ), +# ast.Literal(location=L, value=None) +# ] +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# def test_parser_if_expression_in_binary_op() -> None: +# tokens = tokenize('1 + if true 
then 2 else 3') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Literal(location=L, value=1), +# op='+', +# right=ast.If( +# location=L, +# cond_exp=ast.Literal(location=L, value=True), +# then_exp=ast.Literal(location=L, value=2), +# else_exp=ast.Literal(location=L, value=3) +# ) +# ) +# assert parsed == expected, f'Expected {expected}, got {parsed}' + +# # def test_parser_complex_typed_var_with_locations() -> None: +# # code = 'var obj: ((a, b, c) => T, apple, banana, coconut) => () => T = 5' +# # tokens = tokenize(code) +# # parsed = parse(tokens) + +# # # Build expected locations based on token positions +# # expected = ast.Var( +# # location=SourceLocation(file='', line=0, column=0), # 'var' position +# # name='obj', +# # type=ast.Type( +# # # location=SourceLocation(file='', line=0, column=9), # First '(' +# # params=[ +# # # (a, b, c) => T +# # ast.Type( +# # # location=SourceLocation(file='', line=0, column=10), # Inner '(' +# # params=[ +# # ast.Identifier(SourceLocation(file='', line=0, column=11), 'a'), +# # ast.Identifier(SourceLocation(file='', line=0, column=14), 'b'), +# # ast.Identifier(SourceLocation(file='', line=0, column=17), 'c') +# # ], +# # result=ast.Identifier(SourceLocation(file='', line=0, column=23), 'T') +# # ), +# # ast.Identifier(SourceLocation(file='', line=0, column=26), 'apple'), +# # ast.Identifier(SourceLocation(file='', line=0, column=33), 'banana'), +# # ast.Identifier(SourceLocation(file='', line=0, column=41), 'coconut') +# # ], +# # result=ast.Type( +# # # location=SourceLocation(file='', line=0, column=53), # Empty param '(' +# # params=[], +# # result=ast.Identifier(SourceLocation(file='', line=0, column=59), 'T') +# # ) +# # ), +# # value=ast.Literal(SourceLocation(file='', line=0, column=63), 5) +# # ) + +# # assert parsed == expected, f''' +# # Expected locations: +# # {expected} + +# # Actual locations: +# # {parsed} +# # ''' + +# def test_parser_literals_and_identifiers() -> None: +# 
tokens = tokenize('42 true false x') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_binary_ops_precedence() -> None: +# tokens = tokenize('a + b * c - d / e') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'a'), +# op='+', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'b'), +# op='*', +# right=ast.Identifier(L, 'c') +# ) +# ), +# op='-', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'd'), +# op='/', +# right=ast.Identifier(L, 'e') +# ) +# ) +# assert parsed == expected + +# def test_parser_parentheses() -> None: +# tokens = tokenize('(a + b) * c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'a'), +# op='+', +# right=ast.Identifier(L, 'b') +# ), +# op='*', +# right=ast.Identifier(L, 'c') +# ) +# assert parsed == expected + +# def test_parser_function_calls() -> None: +# tokens = tokenize('f(g(x, y + z), 42)') +# parsed = parse(tokens) +# expected = ast.FunctionCall( +# location=L, +# name='f', +# args=[ +# ast.FunctionCall( +# location=L, +# name='g', +# args=[ +# ast.Identifier(L, 'x'), +# ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'y'), +# op='+', +# right=ast.Identifier(L, 'z') +# ) +# ] +# ), +# ast.Literal(L, 42) +# ] +# ) +# assert parsed == expected + +# def test_parser_if_expressions() -> None: +# tokens = tokenize('if a then b + c else x * y') +# parsed = parse(tokens) +# expected = ast.If( +# location=L, +# cond_exp=ast.Identifier(L, 'a'), +# then_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'b'), +# op='+', +# right=ast.Identifier(L, 'c') +# ), +# else_exp=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'x'), +# op='*', +# right=ast.Identifier(L, 'y') +# ) +# ) +# assert parsed == expected + +# def test_parser_blocks_1() -> None: +# tokens = tokenize('{ a; { b } }') +# parsed = parse(tokens) 
+# expected = ast.Block( +# location=L, +# statements=[ +# ast.Identifier(L, 'a'), +# ast.Block( +# location=L, +# statements=[ast.Identifier(L, 'b')] +# ) +# ] +# ) +# assert parsed == expected + +# def test_parser_var_declarations() -> None: +# tokens = tokenize('var x = 5') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='x', +# value=ast.Literal(L, 5), +# type_f=None +# ) +# assert parsed == expected + +# def test_parser_comparison_ops() -> None: +# tokens = tokenize('a < b >= c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'a'), +# op='<', +# right=ast.Identifier(L, 'b') +# ), +# op='>=', +# right=ast.Identifier(L, 'c') +# ) +# assert parsed == expected + +# def test_parser_assignment_right_associative() -> None: +# tokens = tokenize('a = b = c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'a'), +# op='=', +# right=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'b'), +# op='=', +# right=ast.Identifier(L, 'c') +# ) +# ) +# assert parsed == expected + +# def test_parser_locations() -> None: +# tokens = tokenize('1 + 2') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=SourceLocation(file='', line=0, column=2), +# left=ast.Literal(SourceLocation(file='', line=0, column=0), 1), +# op='+', +# right=ast.Literal(SourceLocation(file='', line=0, column=4), 2) +# ) +# assert parsed == expected + +# def test_parser_typed_var_declaration() -> None: +# tokens = tokenize('var f: (Int, Bool) => Int = expr') +# parsed = parse(tokens) +# expected = ast.Var( +# location=L, +# name='f', +# type_f=FunType( +# # location=L, +# params=[ +# Int_Instance, +# Bool_Instance +# ], +# result=Int_Instance +# ), +# value=ast.Identifier(L, 'expr') +# ) +# assert isinstance(parsed, ast.Var) +# assert parsed.type_f == expected.type_f + +# def test_parser_optional_semicolons() -> None: +# tokens = tokenize('{ 
a { b } }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_error_missing_semicolon() -> None: +# tokens = tokenize('{ a b }') +# with pytest.raises(Exception): +# parse(tokens) + +# def test_parser_unary_ops_1() -> None: +# tokens = tokenize('not -x') +# parsed = parse(tokens) +# expected = ast.UnaryOp( +# location=L, +# op='not', +# right=ast.UnaryOp( +# location=L, +# op='-', +# right=ast.Identifier(L, 'x') +# ) +# ) +# assert parsed == expected + +# def test_parser_logical_ops() -> None: +# tokens = tokenize('a and b or c') +# parsed = parse(tokens) +# expected = ast.BinaryOp( +# location=L, +# left=ast.BinaryOp( +# location=L, +# left=ast.Identifier(L, 'a'), +# op='and', +# right=ast.Identifier(L, 'b') +# ), +# op='or', +# right=ast.Identifier(L, 'c') +# ) +# assert parsed == expected + +# def test_parser_misc() -> None: +# exp = """var x = 1; { var x = 2; x } x""" +# tokens = tokenize(exp) +# parsed = parse(tokens) +# expected = ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='x', value=ast.Literal(location=L, value=1), type_f=None), +# ast.Block( +# location=L, +# statements=[ +# ast.Var(location=L, name='x', value=ast.Literal(location=L, value=2), type_f=None), +# ast.Identifier(location=L, name='x') +# ] +# ), +# ast.Identifier(location=L, name='x') +# ] +# ) +# assert parsed == expected diff --git a/tests/tokenizer_test.py b/tests/tokenizer_test.py new file mode 100644 index 0000000..9e6272e --- /dev/null +++ b/tests/tokenizer_test.py @@ -0,0 +1,28 @@ +from compiler.tokenizer import tokenize, Token, L + +def test_tokenizer_basic() -> None: + assert tokenize('aaa 123 bbb') == [ + Token(location=L, type='identifier', text='aaa'), + Token(location=L, type='int_literal', text='123'), + Token(location=L, type='identifier', text='bbb'), + ] + +def test_tokenizer_newline() -> None: + assert tokenize('if 3\nwhile') == [ + Token(location=L, type='identifier', text='if'), + Token(location=L, type='int_literal', 
text='3'), + Token(location=L, type='identifier', text='while'), + ] + +def test_tokenizer_commments() -> None: + assert tokenize('aaa 123 bbb ; ) ( >= \n ) # aksdjalksjdkajskdjasd\n != // Another comment $') == [ + Token(location=L, type='identifier', text='aaa'), + Token(location=L, type='int_literal', text='123'), + Token(location=L, type='identifier', text='bbb'), + Token(location=L, type='punctuation', text=';'), + Token(location=L, type='punctuation', text=')'), + Token(location=L, type='punctuation', text='('), + Token(location=L, type='operator', text='>='), + Token(location=L, type='punctuation', text=')'), + Token(location=L, type='operator', text='!='), + ] diff --git a/tests/type_checker_test.py b/tests/type_checker_test.py new file mode 100644 index 0000000..76a05c2 --- /dev/null +++ b/tests/type_checker_test.py @@ -0,0 +1,921 @@ +import pytest +from compiler.type_checker import typecheck, TypeSymTab +from compiler import ast +from compiler.types import Int_Instance, Bool_Instance, Unit_Instance, FunType +from compiler.tokenizer import L + +def test_type_checker_assignment_with_unknown_variable() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.Literal(L, 5) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_literal_int() -> None: + node = ast.Literal(location=L, value=5) + result = typecheck(node) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_literal_bool() -> None: + node = ast.Literal(location=L, value=True) + result = typecheck(node) + assert result == Bool_Instance + assert node.type == Bool_Instance + +def test_type_checker_literal_unit() -> None: + node = ast.Literal(location=L, value=None) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_variable_lookup() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.Identifier(location=L, 
name='x') + result = typecheck(node, sym_tab) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_variable_undefined() -> None: + node = ast.Identifier(location=L, name='x') + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_binary_op_add_valid() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, 1), + op='+', + right=ast.Literal(L, 2) + ) + result = typecheck(node) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_binary_op_add_invalid() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, True), + op='+', + right=ast.Literal(L, 2) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_unary_op_negate_valid() -> None: + node = ast.UnaryOp( + location=L, + op='-', + right=ast.Literal(L, 5) + ) + result = typecheck(node) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_unary_op_not_valid() -> None: + node = ast.UnaryOp( + location=L, + op='not', + right=ast.Literal(L, False) + ) + result = typecheck(node) + assert result == Bool_Instance + assert node.type == Bool_Instance + +def test_type_checker_function_call_builtin() -> None: + node = ast.FunctionCall( + location=L, + name='print_int', + args=[ast.Literal(L, 5)] + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_block_last_expr() -> None: + node = ast.Block( + location=L, + statements=[ + ast.Literal(L, 1), + ast.Literal(L, 2) + ] + ) + result = typecheck(node) + assert result == Int_Instance + assert node.statements[-1].type == Int_Instance + +def test_type_checker_block_unit() -> None: + node = ast.Block( + location=L, + statements=[ + ast.Literal(L, 1), + ast.Literal(L, None) + ] + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.statements[-1].type == Unit_Instance + +def 
test_type_checker_if_then_else_valid() -> None: + node = ast.If( + location=L, + cond_exp=ast.Literal(L, True), + then_exp=ast.Literal(L, 1), + else_exp=ast.Literal(L, 2) + ) + result = typecheck(node) + assert result == Int_Instance + assert node.then_exp.type == Int_Instance + +def test_type_checker_if_then_else_mismatch() -> None: + node = ast.If( + location=L, + cond_exp=ast.Literal(L, True), + then_exp=ast.Literal(L, 1), + else_exp=ast.Literal(L, False) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_while_loop() -> None: + node = ast.While( + location=L, + while_exp=ast.Literal(L, True), + do_exp=ast.Literal(L, 5) + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_assignment_valid() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.Literal(L, 5) + ) + result = typecheck(node, sym_tab) + assert result == Int_Instance + assert node.right.type == Int_Instance + +def test_type_checker_assignment_mismatch() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.Literal(L, False) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_var_declaration_inferred() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, 5), + type_f=None + ) + result = typecheck(node) + assert result == Int_Instance + assert node.value.type == Int_Instance + +def test_type_checker_var_declaration_typed_valid() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, 5), + type_f=Int_Instance + ) + result = typecheck(node) + assert result == Int_Instance + assert node.value.type == Int_Instance + +def test_type_checker_var_declaration_typed_mismatch() -> None: + node = ast.Var( + location=L, + name='x', + 
value=ast.Literal(L, True), + type_f=Int_Instance + ) + with pytest.raises(Exception): + typecheck(node) + +# def test_type_checker_function_type_params() -> None: +# node = ast.FunctionCall( +# location=L, +# name='+', +# args=[ +# ast.Literal(L, 1), +# ast.Literal(L, 2) +# ] +# ) +# result = typecheck(node) +# assert result == Int_Instance +# assert node.type == Int_Instance + +def test_type_checker_symbol_table_hierarchy() -> None: + parent = TypeSymTab(locals={'x': Int_Instance}) + child = TypeSymTab(parent=parent) + node = ast.Identifier(L, 'x') + result = typecheck(node, child) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_equality_operator_valid() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='==', + right=ast.Literal(L, 5) + ) + result = typecheck(node) + assert result == Bool_Instance + assert node.type == Bool_Instance + +def test_type_checker_equality_operator_mismatch() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='==', + right=ast.Literal(L, False) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_function_call_user_defined() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance], result=Bool_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ast.Literal(L, 5)] + ) + result = typecheck(node, sym_tab) + assert result == Bool_Instance + assert node.type == Bool_Instance + +def test_type_checker_nested_blocks() -> None: + node = ast.Block( + location=L, + statements=[ + ast.Block( + location=L, + statements=[ + ast.Literal(L, 5), + ast.Block( + location=L, + statements=[ + ast.Literal(L, True) + ] + ) + ] + ) + ] + ) + result = typecheck(node) + assert result == Bool_Instance + assert isinstance(node.statements[0], ast.Block) + assert node.statements[0].statements[-1].type == Bool_Instance + +def test_type_checker_function_type_annotation() -> None: + node = ast.Var( + 
location=L, + name='f', + value=ast.Literal(L, 5), + type_f=FunType(params=[Int_Instance], result=Bool_Instance) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_function_call_with_wrong_arg_count() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ast.Literal(L, 5)] + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_function_call_with_wrong_arg_type() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance], result=Bool_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ast.Literal(L, True)] + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_function_call_with_nested_args() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance, Bool_Instance], result=Bool_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ + ast.Literal(L, 5), + ast.BinaryOp( + location=L, + left=ast.Literal(L, True), + op='and', + right=ast.Literal(L, False) + ) + ] + ) + result = typecheck(node, sym_tab) + assert result == Bool_Instance + assert node.type == Bool_Instance + +def test_type_checker_function_call_with_builtin_print_int() -> None: + node = ast.FunctionCall( + location=L, + name='print_int', + args=[ast.Literal(L, 5)] + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_function_call_with_builtin_read_int() -> None: + node = ast.FunctionCall( + location=L, + name='read_int', + args=[] + ) + result = typecheck(node) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_function_call_with_builtin_print_bool() -> None: + node = ast.FunctionCall( + location=L, + name='print_bool', + args=[ast.Literal(L, True)] + ) + result = typecheck(node) + assert 
result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_function_call_with_unknown_function() -> None: + node = ast.FunctionCall( + location=L, + name='unknown_func', + args=[ast.Literal(L, 5)] + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_if_then_without_else() -> None: + node = ast.If( + location=L, + cond_exp=ast.Literal(L, True), + then_exp=ast.Literal(L, 5), + else_exp=None + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_if_then_with_else() -> None: + node = ast.If( + location=L, + cond_exp=ast.Literal(L, True), + then_exp=ast.Literal(L, 5), + else_exp=ast.Literal(L, 10) + ) + result = typecheck(node) + assert result == Int_Instance + assert node.type == Int_Instance + +def test_type_checker_if_then_with_else_mismatch() -> None: + node = ast.If( + location=L, + cond_exp=ast.Literal(L, True), + then_exp=ast.Literal(L, 5), + else_exp=ast.Literal(L, False) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_while_loop_with_non_bool_condition() -> None: + node = ast.While( + location=L, + while_exp=ast.Literal(L, 5), + do_exp=ast.Literal(L, 10) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_while_loop_with_unit_body() -> None: + node = ast.While( + location=L, + while_exp=ast.Literal(L, True), + do_exp=ast.Literal(L, None) + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_block_with_multiple_statements() -> None: + node = ast.Block( + location=L, + statements=[ + ast.Literal(L, 5), + ast.Literal(L, True), + ast.Literal(L, None) + ] + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_block_with_nested_blocks() -> None: + node = ast.Block( + location=L, + statements=[ + ast.Block( + location=L, + statements=[ + 
ast.Literal(L, 5), + ast.Block( + location=L, + statements=[ + ast.Literal(L, True) + ] + ) + ] + ) + ] + ) + result = typecheck(node) + assert result == Bool_Instance + assert node.type == Bool_Instance + +def test_type_checker_block_with_empty_statements() -> None: + node = ast.Block( + location=L, + statements=[] + ) + result = typecheck(node) + assert result == Unit_Instance + assert node.type == Unit_Instance + +def test_type_checker_var_declaration_with_nested_expression() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='+', + right=ast.Literal(L, 10) + ), + type_f=None + ) + result = typecheck(node) + assert result == Int_Instance + assert node.value.type == Int_Instance + +def test_type_checker_var_declaration_with_typed_initializer() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, 5), + type_f=Int_Instance + ) + result = typecheck(node) + assert result == Int_Instance + assert node.value.type == Int_Instance + +def test_type_checker_var_declaration_with_typed_initializer_mismatch() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, True), + type_f=Int_Instance + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_var_declaration_with_untyped_initializer() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, 5), + type_f=None + ) + result = typecheck(node) + assert result == Int_Instance + assert node.value.type == Int_Instance + +def test_type_checker_var_declaration_with_function_type() -> None: + node = ast.Var( + location=L, + name='f', + value=ast.Literal(L, 5), + type_f=FunType(params=[Int_Instance], result=Bool_Instance) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_assignment_with_function_call() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + 
right=ast.FunctionCall( + location=L, + name='read_int', + args=[] + ) + ) + result = typecheck(node, sym_tab) + assert result == Int_Instance + assert node.right.type == Int_Instance + +def test_type_checker_assignment_with_function_call_mismatch() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.FunctionCall( + location=L, + name='print_int', + args=[ast.Literal(L, 5)] + ) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_assignment_with_nested_expression() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='+', + right=ast.Literal(L, 10) + ) + ) + result = typecheck(node, sym_tab) + assert result == Int_Instance + assert node.right.type == Int_Instance + +def test_type_checker_assignment_with_nested_expression_mismatch() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='+', + right=ast.Literal(L, True) + ) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_assignment_with_wrong_type() -> None: + sym_tab = TypeSymTab(locals={'x': Int_Instance}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.Literal(L, True) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_assignment_with_function_type() -> None: + sym_tab = TypeSymTab(locals={'f': FunType(params=[Int_Instance], result=Bool_Instance)}) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'f'), + op='=', + right=ast.Literal(L, 5) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def 
test_type_checker_assignment_with_function_call_returning_function() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[], result=FunType(params=[Int_Instance], result=Bool_Instance)) + }) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.FunctionCall( + location=L, + name='f', + args=[] + ) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_assignment_with_function_call_returning_int() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[], result=Int_Instance), + 'x': Int_Instance + }) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.FunctionCall( + location=L, + name='f', + args=[] + ) + ) + result = typecheck(node, sym_tab) + assert result == Int_Instance + assert node.right.type == Int_Instance + +def test_type_checker_assignment_with_function_call_returning_bool() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[], result=Bool_Instance), + 'x': Int_Instance + }) + node = ast.BinaryOp( + location=L, + left=ast.Identifier(L, 'x'), + op='=', + right=ast.FunctionCall( + location=L, + name='f', + args=[] + ) + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +import pytest +from compiler.type_checker import typecheck, TypeSymTab +from compiler import ast +from compiler.types import Int_Instance, Bool_Instance, Unit_Instance, FunType +from compiler.tokenizer import L + +def test_type_checker_var_typed_declaration_valid() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, 5), + type_f=Int_Instance + ) + result = typecheck(node) + assert result == Int_Instance + assert node.value.type == Int_Instance + +def test_type_checker_var_typed_declaration_mismatch() -> None: + node = ast.Var( + location=L, + name='x', + value=ast.Literal(L, True), + type_f=Int_Instance + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_assignment_non_identifier() -> 
None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), # Invalid left side + op='=', + right=ast.Literal(L, 10) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_user_function_call_args_count_mismatch() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance, Int_Instance], result=Int_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ast.Literal(L, 5)] # Missing second arg + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_user_function_call_arg_type_mismatch() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance], result=Bool_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ast.Literal(L, True)] # Bool instead of Int + ) + with pytest.raises(Exception): + typecheck(node, sym_tab) + +def test_type_checker_if_then_returns_unit() -> None: + node = ast.If( + location=L, + cond_exp=ast.Literal(L, True), + then_exp=ast.Literal(L, 5), + else_exp=None + ) + result = typecheck(node) + assert result == Unit_Instance + +def test_type_checker_equality_mixed_types() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='==', + right=ast.Literal(L, False) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_logical_op_non_boolean() -> None: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, 5), + op='and', + right=ast.Literal(L, 0) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_nested_scope_shadowing() -> None: + outer = TypeSymTab(locals={'x': Int_Instance}) + inner = TypeSymTab(parent=outer) + inner.locals['x'] = Bool_Instance # Shadow outer x + + node = ast.Identifier(L, 'x') + result = typecheck(node, inner) + assert result == Bool_Instance + +def test_type_checker_function_type_multiple_params() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[Int_Instance, Bool_Instance], 
result=Unit_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[ + ast.Literal(L, 5), + ast.Literal(L, True) + ] + ) + result = typecheck(node, sym_tab) + assert result == Unit_Instance + +def test_type_checker_read_int_builtin() -> None: + node = ast.FunctionCall( + location=L, + name='read_int', + args=[] + ) + result = typecheck(node) + assert result == Int_Instance + +def test_type_checker_block_trailing_semicolon_returns_unit() -> None: + node = ast.Block( + location=L, + statements=[ + ast.Literal(L, 5), + ast.Literal(L, None) # Trailing semicolon case + ] + ) + result = typecheck(node) + assert result == Unit_Instance + +def test_type_checker_not_operator_non_boolean() -> None: + node = ast.UnaryOp( + location=L, + op='not', + right=ast.Literal(L, 5) # Int instead of Bool + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_equality_operator_same_types() -> None: + for val in [5, True]: + node = ast.BinaryOp( + location=L, + left=ast.Literal(L, val), + op='==', + right=ast.Literal(L, val) + ) + result = typecheck(node) + assert result == Bool_Instance + +def test_type_checker_function_type_nested_params() -> None: + # Test function type with complex parameter structure + fun_type = FunType( + params=[[Int_Instance, Bool_Instance], Int_Instance], # From spec's TODO + result=Unit_Instance + ) + sym_tab = TypeSymTab(locals={'f': fun_type}) + + # Valid call with Int then Int (first param allows Int or Bool) + node = ast.FunctionCall( + location=L, + name='f', + args=[ast.Literal(L, 5), ast.Literal(L, 5)] + ) + with pytest.raises(Exception): # Second arg should be Int, but first allows Int + typecheck(node, sym_tab) # Actual logic may vary based on implementation + +def test_type_checker_operator_precedence_type_resolution() -> None: + # Ensure operator precedence doesn't affect type checking + node = ast.BinaryOp( + location=L, + left=ast.BinaryOp( + location=L, + left=ast.Literal(L, 3), + op='*', + 
right=ast.Literal(L, 4) + ), + op='+', + right=ast.Literal(L, 5) + ) + result = typecheck(node) + assert result == Int_Instance + +def test_type_checker_while_loop_condition_non_boolean() -> None: + node = ast.While( + location=L, + while_exp=ast.Literal(L, 5), # Non-bool condition + do_exp=ast.Literal(L, 0) + ) + with pytest.raises(Exception): + typecheck(node) + +def test_type_checker_builtin_print_bool_type() -> None: + node = ast.FunctionCall( + location=L, + name='print_bool', + args=[ast.Literal(L, True)] + ) + result = typecheck(node) + assert result == Unit_Instance + +def test_type_checker_function_return_type_mismatch() -> None: + sym_tab = TypeSymTab(locals={ + 'f': FunType(params=[], result=Int_Instance) + }) + node = ast.FunctionCall( + location=L, + name='f', + args=[] + ) + result = typecheck(node, sym_tab) + assert result == Int_Instance # Dummy test to be replaced with actual logic + +def test_type_checker_complex_function_type_annotation() -> None: + # Test type annotation like (Int, (Bool) => Unit) => Int + fun_type = FunType( + params=[ + Int_Instance, + FunType(params=[Bool_Instance], result=Unit_Instance) + ], + result=Int_Instance + ) + sym_tab = TypeSymTab(locals={'f': fun_type}) + + node = ast.FunctionCall( + location=L, + name='f', + args=[ + ast.Literal(L, 5), + ast.Identifier(L, 'g') # Assume 'g' has correct type + ] + ) + # Actual test would require 'g' to be in symbol table + with pytest.raises(Exception): # 'g' not defined here + typecheck(node, sym_tab) + +def test_type_checker_undefined_function_call() -> None: + node = ast.FunctionCall( + location=L, + name='undefined_func', + args=[] + ) + with pytest.raises(Exception): + typecheck(node)