This commit is contained in:
2026-06-24 17:24:04 +02:00
commit 00c38a12d9
41 changed files with 7289 additions and 0 deletions
+8
View File
@@ -0,0 +1,8 @@
.test-gadget
.*_cache
*.pyc
__pycache__
.git
exercises
program
program.src
+7
View File
@@ -0,0 +1,7 @@
*.pyc
__pycache__
/.mypy_cache
/.pytest_cache
/.test-gadget
program
program.src
+1
View File
@@ -0,0 +1 @@
3.12.7
+18
View File
@@ -0,0 +1,18 @@
FROM python:3.12-alpine3.21
RUN pip install --no-cache-dir poetry==2.0.0
RUN apk add --no-cache bash binutils
WORKDIR /compiler
COPY pyproject.toml poetry.lock .
RUN poetry install --no-root --no-cache
COPY . .
RUN poetry install --no-cache && ./check.sh
EXPOSE 3000
# Setting PYTHONUNBUFFERED forces print() to flush even if stdout is not a TTY.
ENV PYTHONUNBUFFERED=1
CMD ["./compiler.sh", "serve", "--host=0.0.0.0"]
+53
View File
@@ -0,0 +1,53 @@
## Setup
Requirements:
- [Pyenv](https://github.com/pyenv/pyenv) for installing Python 3.12+
- Recommended installation method: the "automatic installer"
i.e. `curl https://pyenv.run | bash`
- Follow the instructions at the end of the output to make pyenv available in your shell.
You may need to restart your shell or even log out and log in again to make
the `pyenv` command available.
- [Poetry](https://python-poetry.org/) for installing dependencies
- Recommended installation method: the "official installer"
i.e. `curl -sSL https://install.python-poetry.org | python3 -`
Install dependencies:
# Install Python specified in `.python-version`
pyenv install
# Install dependencies specified in `pyproject.toml`
poetry install
If `pyenv install` gives an error about `_tkinter`, you can ignore it.
If you see other errors, you may have to investigate.
If you have trouble with Poetry not picking up pyenv's python installation,
try `poetry env remove --all` and then `poetry install` again.
Typecheck and run local unit tests:
./check.sh
# or individually:
poetry run mypy .
poetry run pytest -vv
Once you've finished your compiler, edit `src/__main__.py` to call your compiler in function `call_compiler`.
Then you can run your compiler on a source code file like this:
./compiler.sh compile path/to/source/code --output=path/to/output/file
You can send the finished compiler to Test Gadget for evaluation with:
./test-gadget.py submit
See the course page for more information.
## IDE setup
Recommended VSCode extensions:
- Python
- Pylance
- autopep8
Executable
+6
View File
@@ -0,0 +1,6 @@
#!/bin/bash
set -euo pipefail
cd "$(dirname "${0}")"
poetry run mypy .
rm -Rf test_programs/workdir
poetry run pytest -vv tests/
Executable
+3
View File
@@ -0,0 +1,3 @@
#!/bin/bash
set -euo pipefail
exec poetry -C "$(dirname "${0}")" run main "$@"
+8
View File
@@ -0,0 +1,8 @@
import re
re_1 = r'hello!*'
re_2 = r'hello!* +there'
re_3 = r'(\+|\-|\*|\/)'
re_4 = r'\-?[0-9]+'
re_5 = r'\"([a-z]+ *)+\"'
re_6 = r'\"([[a-z]|\\.]+ *)+\"'
+22
View File
@@ -0,0 +1,22 @@
// a * b
Call(*, [a, b], x1)
// f(g(x + 1))
Call(+, [x, 1], x1)
Call(g, [x2], x2)
Call(f, [x3], x3)
// { f(x); f(y); }
Call(f, [x], x1)
Call(f, [y], x2)
// while a < b do f()
L0
Call(<, [a, b], x1)
CondJump(x1, L1, L2)
L1
Call(f, [], x2)
Jump(L0)
L2
+3
View File
@@ -0,0 +1,3 @@
addq %rax, %rbx
addq %rbx, %rcx
movq %rcx, 48(%rdx)
+25
View File
@@ -0,0 +1,25 @@
<mxfile host="Electron" modified="2025-01-20T21:13:46.282Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/23.1.5 Chrome/120.0.6099.109 Electron/28.1.0 Safari/537.36" etag="lBDD8vVqNHhn5PAcX9s0" version="23.1.5" type="device">
<diagram name="Page-1" id="qzKbqnHofdqwYui6Ryp2">
<mxGraphModel dx="1046" dy="613" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-1" value="a" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="320" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-5" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="uLvV7XOe-hvb0cQ6ErAa-2" target="uLvV7XOe-hvb0cQ6ErAa-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-6" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="uLvV7XOe-hvb0cQ6ErAa-2" target="uLvV7XOe-hvb0cQ6ErAa-3">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-2" value="+" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="360" y="200" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-3" value="b" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="400" y="280" width="40" height="40" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>
BIN
View File
Binary file not shown.
+55
View File
@@ -0,0 +1,55 @@
# Metadata for debuggers and other tools
.global main
.type main, @function
.extern printf
.extern scanf
.section .text # Begins code and data
# Label that marks beginning of main function
main:
# Function stack setup
pushq %rbp
movq %rsp, %rbp
subq $128, %rsp # Reserve 128 bytes of stack space
# Read an integer into -8(%rbp)
movq $scan_format, %rdi # 1st param: "%ld"
leaq -8(%rbp), %rsi # 2nd param: (%rbp - 8)
callq scanf
# If the input was invalid, jump to end
cmpq $1, %rax
jne .Lend
# Read another integer into -16(%rbp)
movq $scan_format, %rdi
leaq -16(%rbp), %rsi
callq scanf
# If the input was invalid, jump to end
cmpq $1, %rax
jne .Lend
# Add the two integers together
movq -8(%rbp), %rax # Load first integer into %rax
addq -16(%rbp), %rax # Add second integer to %rax
# Call function 'printf("%ld\n", %rsi)'
# to print the number in %rsi.
movq $print_format, %rdi
movq %rax, %rsi
callq printf
# Labels starting with ".L" are local to this function,
# i.e. another function than "main" could have its own ".Lend".
.Lend:
# Return from main with status code 0
movq $0, %rax
movq %rbp, %rsp
popq %rbp
ret
# String data that we pass to functions 'scanf' and 'printf'
scan_format:
.asciz "%ld"
print_format:
.asciz "%ld\n"
+49
View File
@@ -0,0 +1,49 @@
<mxfile host="Electron" modified="2025-01-20T21:21:19.237Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/23.1.5 Chrome/120.0.6099.109 Electron/28.1.0 Safari/537.36" etag="SEX2oxCxuKndjGqDOJCb" version="23.1.5" type="device">
<diagram name="Page-1" id="qzKbqnHofdqwYui6Ryp2">
<mxGraphModel dx="1046" dy="613" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="xjGVHfltn5cPiIoTuc5V-1" value="c" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="480" y="240" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-4" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="xjGVHfltn5cPiIoTuc5V-2" target="xjGVHfltn5cPiIoTuc5V-3">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-5" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="xjGVHfltn5cPiIoTuc5V-2" target="xjGVHfltn5cPiIoTuc5V-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-2" value="+" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="440" y="160" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-7" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="xjGVHfltn5cPiIoTuc5V-3" target="xjGVHfltn5cPiIoTuc5V-6">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-3" value="f" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="400" y="240" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-9" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="xjGVHfltn5cPiIoTuc5V-6" target="xjGVHfltn5cPiIoTuc5V-8">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-12" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="xjGVHfltn5cPiIoTuc5V-6" target="xjGVHfltn5cPiIoTuc5V-10">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-6" value="*" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="400" y="320" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-8" value="a" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="360" y="400" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-14" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="xjGVHfltn5cPiIoTuc5V-10" target="xjGVHfltn5cPiIoTuc5V-13">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-10" value="f" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="440" y="400" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="xjGVHfltn5cPiIoTuc5V-13" value="b" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="440" y="480" width="40" height="40" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>
+49
View File
@@ -0,0 +1,49 @@
<mxfile host="Electron" modified="2025-01-20T21:16:45.194Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/23.1.5 Chrome/120.0.6099.109 Electron/28.1.0 Safari/537.36" etag="LboNHYIToo15t9hWjeOL" version="23.1.5" type="device">
<diagram name="Page-1" id="qzKbqnHofdqwYui6Ryp2">
<mxGraphModel dx="1046" dy="613" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-1" value="a" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="320" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-5" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" parent="1" source="uLvV7XOe-hvb0cQ6ErAa-2" target="uLvV7XOe-hvb0cQ6ErAa-1" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-6" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" parent="1" source="uLvV7XOe-hvb0cQ6ErAa-2" target="uLvV7XOe-hvb0cQ6ErAa-3" edge="1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-2" value="+" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="360" y="200" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-3" value="b" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="400" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-1" value="b" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="480" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-2" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="KTHXQ8Rk29yC3revGvwI-4" target="KTHXQ8Rk29yC3revGvwI-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-3" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="KTHXQ8Rk29yC3revGvwI-4" target="KTHXQ8Rk29yC3revGvwI-5">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-4" value="+" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="520" y="200" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-5" value="c" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="560" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-8" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="KTHXQ8Rk29yC3revGvwI-6" target="uLvV7XOe-hvb0cQ6ErAa-2">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-9" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="KTHXQ8Rk29yC3revGvwI-6" target="KTHXQ8Rk29yC3revGvwI-4">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="KTHXQ8Rk29yC3revGvwI-6" value="f" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="440" y="120" width="40" height="40" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>
+25
View File
@@ -0,0 +1,25 @@
<mxfile host="Electron" modified="2025-01-20T21:18:12.329Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/23.1.5 Chrome/120.0.6099.109 Electron/28.1.0 Safari/537.36" etag="tFIp3-svkz9ZbBpyxJu4" version="23.1.5" type="device">
<diagram name="Page-1" id="qzKbqnHofdqwYui6Ryp2">
<mxGraphModel dx="1046" dy="613" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-1" value="a" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="360" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="H9cBYIyhk6fI4PtsZMUV-1" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="uLvV7XOe-hvb0cQ6ErAa-2" target="uLvV7XOe-hvb0cQ6ErAa-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="uLvV7XOe-hvb0cQ6ErAa-2" value="f" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="360" y="200" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="H9cBYIyhk6fI4PtsZMUV-4" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="H9cBYIyhk6fI4PtsZMUV-2" target="uLvV7XOe-hvb0cQ6ErAa-2">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="H9cBYIyhk6fI4PtsZMUV-2" value="f" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="360" y="120" width="40" height="40" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>
@@ -0,0 +1,61 @@
<mxfile host="Electron" modified="2025-01-20T21:26:37.255Z" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/23.1.5 Chrome/120.0.6099.109 Electron/28.1.0 Safari/537.36" etag="5Sx5WTNgAyX2nzQTgUba" version="23.1.5" type="device">
<diagram name="Page-1" id="qzKbqnHofdqwYui6Ryp2">
<mxGraphModel dx="1046" dy="613" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="BO4O0CvvBerr2o1HbyMi-7" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="H9cBYIyhk6fI4PtsZMUV-2" target="BO4O0CvvBerr2o1HbyMi-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-17" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="H9cBYIyhk6fI4PtsZMUV-2" target="BO4O0CvvBerr2o1HbyMi-13">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="H9cBYIyhk6fI4PtsZMUV-2" value="while" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="360" y="120" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-5" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="BO4O0CvvBerr2o1HbyMi-1" target="BO4O0CvvBerr2o1HbyMi-2">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-6" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="BO4O0CvvBerr2o1HbyMi-1" target="BO4O0CvvBerr2o1HbyMi-3">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-1" value="&amp;lt;" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="280" y="200" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-2" value="i" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="240" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-3" value="100" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="320" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-8" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="BO4O0CvvBerr2o1HbyMi-10" target="BO4O0CvvBerr2o1HbyMi-11">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-9" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="BO4O0CvvBerr2o1HbyMi-10" target="BO4O0CvvBerr2o1HbyMi-12">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-10" value="+" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="480" y="280" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-11" value="i" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="440" y="360" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-12" value="1" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="520" y="360" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-15" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="BO4O0CvvBerr2o1HbyMi-13" target="BO4O0CvvBerr2o1HbyMi-14">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-16" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="BO4O0CvvBerr2o1HbyMi-13" target="BO4O0CvvBerr2o1HbyMi-10">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-13" value="=" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="440" y="200" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="BO4O0CvvBerr2o1HbyMi-14" value="i" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="400" y="280" width="40" height="40" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>
+3
View File
@@ -0,0 +1,3 @@
[mypy]
disallow_untyped_defs = True
disallow_untyped_calls = True
Generated
+173
View File
@@ -0,0 +1,173 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "autopep8"
version = "2.3.1"
description = "A tool that automatically formats Python code to conform to the PEP 8 style guide"
optional = false
python-versions = ">=3.8"
files = [
{file = "autopep8-2.3.1-py2.py3-none-any.whl", hash = "sha256:a203fe0fcad7939987422140ab17a930f684763bf7335bdb6709991dd7ef6c2d"},
{file = "autopep8-2.3.1.tar.gz", hash = "sha256:8d6c87eba648fdcfc83e29b788910b8643171c395d9c4bcf115ece035b9c9dda"},
]
[package.dependencies]
pycodestyle = ">=2.12.0"
[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "mypy"
version = "1.13.0"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"},
{file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"},
{file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"},
{file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"},
{file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"},
{file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"},
{file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"},
{file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"},
{file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"},
{file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"},
{file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"},
{file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"},
{file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"},
{file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"},
{file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"},
{file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"},
{file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"},
{file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"},
{file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"},
{file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"},
{file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"},
{file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"},
{file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"},
{file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"},
{file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"},
{file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"},
{file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"},
{file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"},
{file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"},
{file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"},
{file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"},
{file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"},
]
[package.dependencies]
mypy-extensions = ">=1.0.0"
typing-extensions = ">=4.6.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
faster-cache = ["orjson"]
install-types = ["pip"]
mypyc = ["setuptools (>=50)"]
reports = ["lxml"]
[[package]]
name = "mypy-extensions"
version = "1.0.0"
description = "Type system extensions for programs checked with the mypy type checker."
optional = false
python-versions = ">=3.5"
files = [
{file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
]
[[package]]
name = "packaging"
version = "24.2"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
{file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
]
[[package]]
name = "pluggy"
version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
{file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pycodestyle"
version = "2.12.1"
description = "Python style guide checker"
optional = false
python-versions = ">=3.8"
files = [
{file = "pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3"},
{file = "pycodestyle-2.12.1.tar.gz", hash = "sha256:6838eae08bbce4f6accd5d5572075c63626a15ee3e6f842df996bf62f6d73521"},
]
[[package]]
name = "pytest"
version = "8.3.3"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"},
{file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=1.5,<2"
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "typing-extensions"
version = "4.12.2"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "b3bb39fa5b66f0bdb5a3c762a81c8fc01c30b0103c4ff37a063e4f2e32af0a6f"
+31
View File
@@ -0,0 +1,31 @@
[tool.poetry]
name = "compilers-project"
version = "0.0.0"
description = ""
authors = []
readme = "README.md"
packages = [{include = "compiler", from = "src", format = ["sdist"]}]
[tool.poetry.dependencies]
python = "^3.12"
[tool.poetry.group.dev.dependencies]
autopep8 = "^2.3.1"
mypy = "^1.13.0"
pytest = "^8.3.3"
[tool.poetry.scripts]
main = "compiler.__main__:main"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
pythonpath = "src"
addopts = [
"--import-mode=importlib",
]
[virtualenvs]
prefer-active-python = true
View File
+117
View File
@@ -0,0 +1,117 @@
from base64 import b64encode
import json
import re
import sys
from socketserver import ForkingTCPServer, StreamRequestHandler
from traceback import format_exception
from typing import Any
from compiler.tokenizer import tokenize
from compiler.parser import parse
from compiler.type_checker import typecheck_module
from compiler.ir_generator import generate_ir_from_module
from compiler.assembly_generator import generate_assembly_from_dict
from compiler.assembler import assemble_and_get_executable
def call_compiler(source_code: str, input_file_name: str) -> bytes:
# *** TODO ***
# Call your compiler here and return the compiled executable.
# Raise an exception on compilation error.
#
# The input file name is informational only: you can optionally include in your source locations and error messages,
# or you can ignore it.
# *** TODO ***
tokens = tokenize(source_code)
ast = parse(tokens)
typecheck_module(ast)
instructions = generate_ir_from_module(ast)
assembly_code = generate_assembly_from_dict(instructions)
print(assembly_code)
return assemble_and_get_executable(assembly_code)
def main() -> int:
# === Option parsing ===
command: str | None = None
input_file: str | None = None
output_file: str | None = None
host = "127.0.0.1"
port = 3000
for arg in sys.argv[1:]:
if (m := re.fullmatch(r'--output=(.+)', arg)) is not None:
output_file = m[1]
elif (m := re.fullmatch(r'--host=(.+)', arg)) is not None:
host = m[1]
elif (m := re.fullmatch(r'--port=(.+)', arg)) is not None:
port = int(m[1])
elif arg.startswith('-'):
raise Exception(f"Unknown argument: {arg}")
elif command is None:
command = arg
elif input_file is None:
input_file = arg
else:
raise Exception("Multiple input files not supported")
if command is None:
print(f"Error: command argument missing", file=sys.stderr)
return 1
def read_source_code() -> str:
if input_file is not None:
with open(input_file) as f:
return f.read()
else:
return sys.stdin.read()
# === Command implementations ===
if command == 'compile':
source_code = read_source_code()
if output_file is None:
raise Exception("Output file flag --output=... required")
executable = call_compiler(source_code, input_file or '(source code)')
with open(output_file, 'wb') as f:
f.write(executable)
elif command == 'serve':
try:
run_server(host, port)
except KeyboardInterrupt:
pass
else:
print(f"Error: unknown command: {command}", file=sys.stderr)
return 1
return 0
def run_server(host: str, port: int) -> None:
class Server(ForkingTCPServer):
allow_reuse_address = True
request_queue_size = 32
class Handler(StreamRequestHandler):
def handle(self) -> None:
result: dict[str, Any] = {}
try:
input_str = self.rfile.read().decode()
input = json.loads(input_str)
if input["command"] == "compile":
source_code = input["code"]
executable = call_compiler(source_code, "(source code)")
result["program"] = b64encode(executable).decode()
elif input["command"] == "ping":
pass
else:
result["error"] = "Unknown command: " + input['command']
except Exception as e:
result["error"] = "".join(format_exception(e))
result_str = json.dumps(result)
self.request.sendall(str.encode(result_str))
print(f"Starting TCP server at {host}:{port}")
with Server((host, port), Handler) as server:
server.serve_forever()
if __name__ == '__main__':
sys.exit(main())
+373
View File
@@ -0,0 +1,373 @@
import subprocess
import tempfile
from contextlib import nullcontext
from os import path
from typing import Any, Callable, ContextManager, TypeVar
import shutil
from pathlib import Path
T = TypeVar('T')
def assemble(
assembly_code: str,
output_file: str,
workdir: str | None = None,
tempfile_basename: str = 'program',
link_with_c: bool = False,
extra_libraries: list[str] = [],
) -> None:
"""Invokes 'as' and 'ld' to generate an executable file from Assembly code.
The file is written to the given path.
"""
_assemble(
assembly_code=assembly_code,
workdir=workdir,
tempfile_basename=tempfile_basename,
link_with_c=link_with_c,
extra_libraries=extra_libraries,
take_output=lambda f: shutil.move(f, output_file)
)
def assemble_and_get_executable(
assembly_code: str,
workdir: str | None = None,
tempfile_basename: str = 'program',
link_with_c: bool = False,
extra_libraries: list[str] = [],
) -> bytes:
"""Invokes 'as' and 'ld' to generate an executable file from Assembly code.
The file is returned.
"""
return _assemble(
assembly_code=assembly_code,
workdir=workdir,
tempfile_basename=tempfile_basename,
link_with_c=link_with_c,
extra_libraries=extra_libraries,
take_output=lambda f: Path(f).read_bytes()
)
def _assemble(
assembly_code: str,
workdir: str | None,
tempfile_basename: str,
link_with_c: bool,
extra_libraries: list[str],
take_output: Callable[[str], T],
) -> T:
if workdir is not None:
wd = Path(workdir).absolute().as_posix()
return _assemble_impl(assembly_code, wd, tempfile_basename, link_with_c, extra_libraries, take_output)
else:
with tempfile.TemporaryDirectory(prefix='compiler_') as wd:
return _assemble_impl(assembly_code, wd, tempfile_basename, link_with_c, extra_libraries, take_output)
def _assemble_impl(
assembly_code: str,
workdir: str,
tempfile_basename: str,
link_with_c: bool,
extra_libraries: list[str],
take_output: Callable[[str], T],
) -> T:
stdlib_asm = path.join(workdir, 'stdlib.s')
stdlib_obj = path.join(workdir, 'stdlib.o')
program_asm = path.join(workdir, f'{tempfile_basename}.s')
program_obj = path.join(workdir, f'{tempfile_basename}.o')
output_file = path.join(workdir, 'a.out')
if link_with_c:
final_stdlib_asm_code = drop_start_symbol(stdlib_asm_code)
else:
final_stdlib_asm_code = stdlib_asm_code
with open(stdlib_asm, 'w') as f:
f.write(final_stdlib_asm_code)
with open(program_asm, 'w') as f:
f.write(assembly_code)
subprocess.run(['as', '-g', '-o' +
stdlib_obj, stdlib_asm], check=True)
subprocess.run(['as', '-g', '-o' +
program_obj, program_asm], check=True)
linker_flags = ['-static', *[f'-l{lib}' for lib in extra_libraries]]
if link_with_c:
# Linking with the C standard library correctly is complicated,
# as evidenced by the complicated linker command shown by `cc -v something.c`.
# Instead of trying to build the right `ld` command ourselves, we use the C compiler
# to do the linking.
subprocess.run(
['cc', '-o' + output_file, *linker_flags, stdlib_obj, program_obj], check=True)
else:
subprocess.run(
['ld', '-o' + output_file, *linker_flags, stdlib_obj, program_obj], check=True)
return take_output(output_file)
def drop_start_symbol(code: str) -> str:
return code.split('# BEGIN START')[0] + code.split('# END START')[1]
# WARNING: if you want to copy this into a separate file,
# replace all double backslashes `\\` with a single backslash `\`.
stdlib_asm_code: str = """
.global _start
.global print_int
.global print_bool
.global read_int
.extern main
.section .text
# BEGIN START (we skip this part when linking with C)
# ***** Function '_start' *****
# Calls function 'main' and halts the program
_start:
call main
movq $60, %rax
xorq %rdi, %rdi
syscall
# END START
# ***** Function 'print_int' *****
# Prints a 64-bit signed integer followed by a newline.
#
# We'll build up the digits to print on the stack.
# We generate the least significant digit first,
# and the stack grows downward, so that works out nicely.
#
# Algorithm:
# push(newline)
# if x < 0:
# negative = true
# x = -x
# while x > 0:
# push(digit for (x % 10))
# x = x / 10
# if negative:
# push(minus sign)
# syscall 'write' with pushed data
# return the original argument
#
# Registers:
# - rdi = our input number, which we divide down as we go
# - rsp = stack pointer, pointing to the next character to emit.
# - rbp = pointer to one after the last byte of our output (which grows downward)
# - r9 = whether the number was negative
# - r10 = a copy of the original input, so we can return it
# - rax, rcx and rdx are used by intermediate computations
print_int:
pushq %rbp # Save previous stack frame pointer
movq %rsp, %rbp # Set stack frame pointer
movq %rdi, %r10 # Back up original input
decq %rsp # Point rsp at first byte of output
# TODO: this non-alignment confuses debuggers. Use a different register?
# Add newline as the last output byte
movb $10, (%rsp) # ASCII newline = 10
decq %rsp
# Check for zero and negative cases
xorq %r9, %r9
xorq %rax, %rax
cmpq $0, %rdi
je .Ljust_zero
jge .Ldigit_loop
incq %r9 # If < 0, set %r9 to 1
.Ldigit_loop:
cmpq $0, %rdi
je .Ldigits_done # Loop done when input = 0
# Divide rdi by 10
movq %rdi, %rax
movq $10, %rcx
cqto
idivq %rcx # Sets rax = quotient and rdx = remainder
movq %rax, %rdi # The quotient becomes our remaining input
cmpq $0, %rdx # If the remainder is negative (because the input is), negate it
jge .Lnot_negative
negq %rdx
.Lnot_negative:
addq $48, %rdx # ASCII '0' = 48. Add the remainder to get the correct digit.
movb %dl, (%rsp) # Store the digit in the output
decq %rsp
jmp .Ldigit_loop
.Ljust_zero:
movb $48, (%rsp) # ASCII '0' = 48
decq %rsp
.Ldigits_done:
# Add minus sign if negative
cmpq $0, %r9
je .Lminus_done
movb $45, (%rsp) # ASCII '-' = 45
decq %rsp
.Lminus_done:
# Call syscall 'write'
movq $1, %rax # rax = syscall number for write
movq $1, %rdi # rdi = file handle for stdout
# rsi = pointer to message
movq %rsp, %rsi
incq %rsi
# rdx = number of bytes
movq %rbp, %rdx
subq %rsp, %rdx
decq %rdx
syscall
# Restore stack registers and return the original input
movq %rbp, %rsp
popq %rbp
movq %r10, %rax
ret
# ***** Function 'print_bool' *****
# Prints either 'true' or 'false', followed by a newline.
print_bool:
pushq %rbp # Save previous stack frame pointer
movq %rsp, %rbp # Set stack frame pointer
movq %rdi, %r10 # Back up original input
cmpq $0, %rdi # See if the argument is false (i.e. 0)
jne .Ltrue
movq $false_str, %rsi # If so, set %rsi to the address of the string for false
movq $false_str_len, %rdx # and %rdx to the length of that string,
jmp .Lwrite
.Ltrue:
movq $true_str, %rsi # otherwise do the same with the string for true.
movq $true_str_len, %rdx
.Lwrite:
# Call syscall 'write'
movq $1, %rax # rax = syscall number for write
movq $1, %rdi # rdi = file handle for stdout
# rsi = pointer to message (already set above)
# rdx = number of bytes (already set above)
syscall
# Restore stack registers and return the original input
movq %rbp, %rsp
popq %rbp
movq %r10, %rax
ret
true_str:
.ascii "true\\n"
true_str_len = . - true_str
false_str:
.ascii "false\\n"
false_str_len = . - false_str
# ***** Function 'read_int' *****
# Reads an integer from stdin, skipping non-digit characters, until a newline.
#
# To avoid the complexity of buffering, it very inefficiently
# makes a syscall to read each byte.
#
# It crashes the program if input could not be read.
read_int:
pushq %rbp # Save previous stack frame pointer
movq %rsp, %rbp # Set stack frame pointer
pushq %r12 # Back up r12 since it's callee-saved
pushq $0 # Reserve space for input
# (we only write the lowest byte,
# but loading 64-bits at once is easier)
xorq %r9, %r9 # Clear r9 - it'll store the minus sign
xorq %r10, %r10 # Clear r10 - it'll accumulate our output
# Skip r11 - syscalls destroy it
xorq %r12, %r12 # Clear r12 - it'll count the number of input bytes read.
# Loop until a newline or end of input is encountered
.Lloop:
# Call syscall 'read'
xorq %rax, %rax # syscall number for read = 0
xorq %rdi, %rdi # file handle for stdin = 0
movq %rsp, %rsi # rsi = pointer to buffer
movq $1, %rdx # rdx = buffer size
syscall # result in rax = number of bytes read,
# or 0 on end of input, -1 on error
# Check return value: either -1, 0 or 1.
cmpq $0, %rax
jg .Lno_error
je .Lend_of_input
jmp .Lerror
.Lend_of_input:
cmpq $0, %r12
je .Lerror # If we've read no input, it's an error.
jmp .Lend # Otherwise complete reading this input.
.Lno_error:
incq %r12 # Increment input byte counter
movq (%rsp), %r8 # Load input byte to r8
# If the input byte is 10 (newline), exit the loop
cmpq $10, %r8
je .Lend
# If the input byte is 45 (minus sign), negate r9
cmpq $45, %r8
jne .Lnegation_done
xorq $1, %r9
.Lnegation_done:
# If the input byte is not between 48 ('0') and 57 ('9')
# then skip it as a junk character.
cmpq $48, %r8
jl .Lloop
cmpq $57, %r8
jg .Lloop
# Subtract 48 to get a digit 0..9
subq $48, %r8
# Shift the digit onto the result
imulq $10, %r10
addq %r8, %r10
jmp .Lloop
.Lend:
# If it's a negative number, negate the result
cmpq $0, %r9
je .Lfinal_negation_done
neg %r10
.Lfinal_negation_done:
# Restore stack registers and return the result
popq %r12
movq %rbp, %rsp
popq %rbp
movq %r10, %rax
ret
.Lerror:
# Write error message to stderr with syscall 'write'
movq $1, %rax
movq $2, %rdi
movq $read_int_error_str, %rsi
movq $read_int_error_str_len, %rdx
syscall
# Exit the program
movq $60, %rax # Syscall number for exit = 60.
movq $1, %rdi # Set exit code 1.
syscall
read_int_error_str:
.ascii "Error: read_int() failed to read input\\n"
read_int_error_str_len = . - read_int_error_str
"""
+374
View File
@@ -0,0 +1,374 @@
import uuid
from dataclasses import fields
from compiler.ir import IRVar, Instruction, LoadIntConst, LoadBoolConst, Label, CondJump, Jump, Copy, Call, Return, LoadCustomFuncArgs
from compiler.ir_generator import SymTab
from compiler.intrinsics import all_intrinsics, IntrinsicArgs
class Locals:
"""Knows the memory location of every local variable."""
_var_to_location: dict[IRVar, str]
_stack_used: int
def __init__(self, variables: list[IRVar]) -> None:
curr_loc = -8
self._var_to_location = {}
for ir_var in variables:
self._var_to_location[ir_var] = f'{curr_loc}(%rbp)'
curr_loc -= 8
self._stack_used = 8 * len(variables)
def get_ref(self, v: IRVar) -> str:
"""Returns an Assembly reference like `-24(%rbp)`
for the memory location that stores the given variable"""
return self._var_to_location[v]
def stack_used(self) -> int:
"""Returns the number of bytes of stack space needed for the local variables."""
return self._stack_used
def get_all_ir_variables(instructions: list[Instruction]) -> list[IRVar]:
result_list: list[IRVar] = []
result_set: set[IRVar] = set()
def add(v: IRVar) -> None:
if v not in result_set:
result_list.append(v)
result_set.add(v)
for insn in instructions:
for field in fields(insn):
value = getattr(insn, field.name)
if isinstance(value, IRVar):
add(value)
elif isinstance(value, list):
for v in value:
if isinstance(v, IRVar):
add(v)
return result_list
def get_all_labels(instructions: list[Instruction]) -> list[Label]:
result_list: list[Label] = []
for ins in instructions:
if isinstance(ins, Label):
result_list.append(ins)
elif isinstance(ins, Jump):
result_list.append(ins.label)
elif isinstance(ins, CondJump):
result_list.append(ins.then_label)
result_list.append(ins.else_label)
return result_list
def generate_assembly(instructions: list[Instruction]) -> list[str]:
lines = []
def emit(line: str) -> None: lines.append(line)
locals = Locals(
variables=get_all_ir_variables(instructions)
)
# ... Emit initial declarations and stack setup here ...
emit('.global main')
emit('.type main, @function')
emit('main:')
emit('pushq %rbp')
emit('movq %rsp, %rbp')
emit(f'subq ${locals.stack_used()}, %rsp')
for insn in instructions:
emit('# ' + str(insn))
match insn:
case Label():
emit("")
# ".L" prefix marks the symbol as "private".
# This makes GDB backtraces look nicer too:
# https://stackoverflow.com/a/26065570/965979
emit(f'.L{insn.name}:')
case LoadIntConst():
if -2**31 <= insn.value < 2**31:
emit(f'movq ${insn.value}, {locals.get_ref(insn.dest)}')
else:
# Due to a quirk of x86-64, we must use
# a different instruction for large integers.
# It can only write to a register,
# not a memory location, so we use %rax
# as a temporary.
emit(f'movabsq ${insn.value}, %rax')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
case Jump():
emit(f'jmp .L{insn.label.name}')
case LoadBoolConst():
if insn.value:
emit(f'movq $1, {locals.get_ref(insn.dest)}')
else:
emit(f'movq $0, {locals.get_ref(insn.dest)}')
case Copy():
emit(f'movq {locals.get_ref(insn.source)}, %rax')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
case CondJump():
emit(f'cmpq $0, {locals.get_ref(insn.cond)}')
emit(f'jne .L{insn.then_label.name}')
emit(f'jmp .L{insn.else_label.name}')
case Call():
if insn.fun.name in list(all_intrinsics.keys()):
all_intrinsics[insn.fun.name](IntrinsicArgs(
arg_refs=[locals.get_ref(arg) for arg in insn.args],
result_register='%rax',
emit=emit
))
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
else:
if len(insn.args) == 0:
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 1:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 2:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 3:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 4:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 5:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'movq {locals.get_ref(insn.args[4])}, %r8')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 6:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'movq {locals.get_ref(insn.args[4])}, %r8')
emit(f'movq {locals.get_ref(insn.args[5])}, %r9')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) > 6:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'movq {locals.get_ref(insn.args[4])}, %r8')
emit(f'movq {locals.get_ref(insn.args[5])}, %r9')
for i in range(len(insn.args) - 1, 5, -1):
emit(f'pushq {locals.get_ref(insn.args[i])}')
emit(f'callq {insn.fun.name}')
emit(f'addq ${8 * (len(insn.args) - 6)}, %rsp')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
case Return():
pass
emit(f'movq %rbp, %rsp')
emit(f'popq %rbp')
emit(f'ret')
return lines
def generate_assembly_from_dict(instruction_dict: dict[str, list[Instruction]]) -> str:
lines = []
def emit(line: str) -> None: lines.append(line)
emit('.extern print_int')
emit('.extern print_bool')
emit('.extern read_int')
emit('.section .text')
for k, v in instruction_dict.items():
if k != 'main':
labels = get_all_labels(v)
variables=get_all_ir_variables(v)
locals = Locals(variables)
st = SymTab[str]()
for label in labels:
label_uuid = str(uuid.uuid4()).replace('-', '')
st.assign(label.name, label_uuid)
emit(f'.global {k}')
emit(f'.type {k}, @function')
emit(f'{k}:')
emit('pushq %rbp')
emit('movq %rsp, %rbp')
emit(f'subq ${locals.stack_used()}, %rsp')
for insn in v:
emit('# ' + str(insn))
match insn:
case LoadCustomFuncArgs():
for i in range(len(insn.args)):
reg = ['%rdi', '%rsi', '%rdx', '%rcx', '%r8', '%r9'][i]
var = insn.args[i]
emit(f'movq {reg}, {locals.get_ref(var)}')
case Label():
emit("")
label_uuid = st.lookup(insn.name)
emit(f'.L{label_uuid}:')
case LoadIntConst():
if -2**31 <= insn.value < 2**31:
emit(f'movq ${insn.value}, {locals.get_ref(insn.dest)}')
else:
emit(f'movabsq ${insn.value}, %rax')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
case Jump():
label_uuid = st.lookup(insn.label.name)
emit(f'jmp .L{label_uuid}')
case LoadBoolConst():
if insn.value:
emit(f'movq $1, {locals.get_ref(insn.dest)}')
else:
emit(f'movq $0, {locals.get_ref(insn.dest)}')
case Copy():
emit(f'movq {locals.get_ref(insn.source)}, %rax')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
case CondJump():
emit(f'cmpq $0, {locals.get_ref(insn.cond)}')
then_label_uuid = st.lookup(insn.then_label.name)
emit(f'jne .L{then_label_uuid}')
else_label_uuid = st.lookup(insn.else_label.name)
emit(f'jmp .L{else_label_uuid}')
case Call():
if insn.fun.name in list(all_intrinsics.keys()):
all_intrinsics[insn.fun.name](IntrinsicArgs(
arg_refs=[locals.get_ref(arg) for arg in insn.args],
result_register='%rax',
emit=emit
))
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
else:
if len(insn.args) == 0:
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 1:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 2:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 3:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 4:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 5:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'movq {locals.get_ref(insn.args[4])}, %r8')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) == 6:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'movq {locals.get_ref(insn.args[4])}, %r8')
emit(f'movq {locals.get_ref(insn.args[5])}, %r9')
emit(f'callq {insn.fun.name}')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
elif len(insn.args) > 6:
emit(f'movq {locals.get_ref(insn.args[0])}, %rdi')
emit(f'movq {locals.get_ref(insn.args[1])}, %rsi')
emit(f'movq {locals.get_ref(insn.args[2])}, %rdx')
emit(f'movq {locals.get_ref(insn.args[3])}, %rcx')
emit(f'movq {locals.get_ref(insn.args[4])}, %r8')
emit(f'movq {locals.get_ref(insn.args[5])}, %r9')
for i in range(len(insn.args) - 1, 5, -1):
emit(f'pushq {locals.get_ref(insn.args[i])}')
emit(f'callq {insn.fun.name}')
emit(f'addq ${8 * (len(insn.args) - 6)}, %rsp')
emit(f'movq %rax, {locals.get_ref(insn.dest)}')
case Return():
emit(f'movq {locals.get_ref(insn.result)}, %rax')
emit(f'movq %rbp, %rsp')
emit(f'popq %rbp')
emit(f'ret')
main_lines: list[str] = generate_assembly(instruction_dict['main'])
lines.extend(main_lines)
return '\n'.join(lines)
+85
View File
@@ -0,0 +1,85 @@
from __future__ import annotations
from dataclasses import dataclass, field
from compiler.tokenizer import Location
from compiler.types import Type, Unit_Instance
@dataclass
class Expression:
"""Base class for AST nodes representing expressions."""
location: Location
type: Type = field(kw_only=True, default=Unit_Instance)
@dataclass
class Literal(Expression):
value: int | bool | None
@dataclass
class Identifier(Expression):
name: str
@dataclass
class BinaryOp(Expression):
"""AST node for a binary operation like `A + B`"""
left: Expression
op: str
right: Expression
@dataclass
class UnaryOp(Expression):
op: str
right: Expression
@dataclass
class If(Expression):
cond_exp: Expression
then_exp: Expression
else_exp: Expression | None
@dataclass
class While(Expression):
while_exp: Expression
do_exp: Expression
@dataclass
class FunctionArg():
name: str
type: Type
@dataclass
class FunctionDef():
name: str
args: list[FunctionArg]
return_type: Type
block: Block
@dataclass
class Module():
function_defs: list[FunctionDef]
block: Block
@dataclass
class FunctionCall(Expression):
name: str
args: list[Expression]
@dataclass
class Block(Expression):
statements: list[Expression]
@dataclass
class Var(Expression):
name: str
value: Expression
type_f: Type | None
@dataclass
class Break(Expression):
pass
@dataclass
class Continue(Expression):
pass
@dataclass
class Return(Expression):
result: Expression
+311
View File
@@ -0,0 +1,311 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from compiler import ast
from compiler.tokenizer import L
from types import FunctionType
def unary_op_negate_int(node: ast.UnaryOp, sym_tab: SymTab) -> int:
right: Any = interpret(node.right, sym_tab)
if type(right) is int:
return -right
else:
raise Exception()
def unary_op_negate_bool(node: ast.UnaryOp, sym_tab: SymTab) -> bool:
right: Any = interpret(node.right, sym_tab)
if type(right) is bool:
return not right
else:
raise Exception()
def binary_op_and(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
if not a:
return False
b: Any = interpret(node.right, sym_tab)
return b
def binary_op_or(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
if a:
return True
b: Any = interpret(node.right, sym_tab)
return b
def binary_op_add(node: ast.BinaryOp, sym_tab: SymTab) -> int:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a + b
def binary_op_subtract(node: ast.BinaryOp, sym_tab: SymTab) -> int:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a - b
def binary_op_multiply(node: ast.BinaryOp, sym_tab: SymTab) -> int:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a * b
def binary_op_divide(node: ast.BinaryOp, sym_tab: SymTab) -> int:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return int(a / b)
def binary_op_modulo(node: ast.BinaryOp, sym_tab: SymTab) -> int:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a % b
def binary_op_lt(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a < b
def binary_op_gt(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a > b
def binary_op_gte(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a >= b
def binary_op_lte(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a <= b
def binary_op_diff(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a != b
def binary_op_eq(node: ast.BinaryOp, sym_tab: SymTab) -> bool:
a: Any = interpret(node.left, sym_tab)
b: Any = interpret(node.right, sym_tab)
return a == b
def built_in_function_print_int(x: int) -> None:
print(x)
def built_in_function_print_bool(x: bool) -> None:
if x:
print('true')
else:
print('false')
def built_in_function_read_int() -> int:
return int(input())
class SymTab:
locals: dict[str, dict[str, FunctionType] | Any]
parent: SymTab | None
def __init__(self, locals: dict = {}, parent: SymTab | None = None) -> None:
self.locals = locals
self.parent = parent
if parent is None:
self.locals.update({
'binary_ops': {
'and': binary_op_and,
'or' : binary_op_or,
'+': binary_op_add,
'-': binary_op_subtract,
'*': binary_op_multiply,
'/': binary_op_divide,
'%': binary_op_modulo,
'<': binary_op_lt,
'>': binary_op_gt,
'>=': binary_op_gte,
'<=': binary_op_lte,
'!=': binary_op_diff,
'==': binary_op_eq,
},
'unary_ops': {
'-': unary_op_negate_int,
'not': unary_op_negate_bool
},
'built_in_functions': {
'print_int': built_in_function_print_int,
'print_bool': built_in_function_print_bool,
'read_int': built_in_function_read_int
}
})
def get_top_level(self) -> SymTab:
if self.parent is None:
return self
return self.parent.get_top_level()
def assign(self, var_name: str, var_value: Any) -> None:
self.locals[var_name] = var_value
def assign_recursive(self, var_name: str, var_value: Any) -> None:
if self.locals.get(var_name) is not None:
self.locals[var_name] = var_value
else:
if self.parent is not None:
self.parent.assign_recursive(var_name, var_value)
else:
raise Exception()
def lookup(self, var_name: str) -> Any:
return self.locals.get(var_name)
def lookup_recursive(self, var_name: str) -> Any:
if self.locals.get(var_name) is not None:
return self.locals.get(var_name)
else:
if self.parent is not None:
return self.parent.lookup_recursive(var_name)
else:
raise Exception()
@dataclass
class Unit:
pass
type Value = int | bool | Unit | FunctionType | None
def interpret(node: ast.Expression | None, sym_tab: SymTab | None = None) -> Value:
if sym_tab is None:
sym_tab = SymTab()
match node:
case None:
return None
case ast.Literal():
if node == ast.Literal(location=L, value=None):
return Unit()
else:
return node.value
case ast.Identifier():
return sym_tab.lookup_recursive(node.name)
case ast.UnaryOp():
if sym_tab.get_top_level().locals['unary_ops'].get(node.op) is not None:
to_be_executed_function = sym_tab.get_top_level().locals['unary_ops'].get(node.op)
assert isinstance(to_be_executed_function, FunctionType)
return to_be_executed_function(node, sym_tab)
else:
raise Exception()
case ast.BinaryOp():
if node.op == '=':
a: Any = node.left
b: Any = interpret(node.right, sym_tab)
if isinstance(a, ast.Identifier):
sym_tab.assign_recursive(a.name, b)
else:
raise Exception()
return b
else:
if sym_tab.get_top_level().locals['binary_ops'].get(node.op) is not None:
to_be_executed_function = sym_tab.get_top_level().locals['binary_ops'].get(node.op)
assert isinstance(to_be_executed_function, FunctionType)
return to_be_executed_function(node, sym_tab)
else:
raise Exception()
case ast.If():
e1 = node.cond_exp
e2 = node.then_exp
e3 = node.else_exp
# if-then
if e3 is None:
e1_value = interpret(e1, sym_tab)
if e1_value == True:
interpret(e2, sym_tab)
return Unit()
elif e1_value == False:
return Unit()
else:
raise Exception()
# if-then-else
else:
e1_value = interpret(e1, sym_tab)
if e1_value == True:
return interpret(e2, sym_tab)
elif e1_value == False:
return interpret(e3, sym_tab)
else:
raise Exception()
case ast.While():
while interpret(node.while_exp, sym_tab):
interpret(node.do_exp, sym_tab)
if not interpret(node.while_exp, sym_tab):
return Unit()
else:
raise Exception()
case ast.FunctionCall():
if node.name == 'print_int':
if len(node.args) != 1:
raise Exception()
arg_value = interpret(node.args[0], sym_tab)
sym_tab.get_top_level().locals['built_in_functions']['print_int'](arg_value)
return Unit()
elif node.name == 'print_bool':
if len(node.args) != 1:
raise Exception()
arg_value = interpret(node.args[0], sym_tab)
sym_tab.get_top_level().locals['built_in_functions']['print_bool'](arg_value)
return Unit()
elif node.name == 'read_int':
return sym_tab.get_top_level().locals['built_in_functions']['read_int']()
else:
# User-defined functions
interpreted_args = []
for arg in node.args:
interpreted_args.append(interpret(arg, sym_tab))
return sym_tab.lookup_recursive(node.name)(*interpreted_args)
case ast.Block():
if node.statements == []:
return Unit()
new_sym_tab = SymTab(locals={}, parent=sym_tab)
for index, statement in enumerate(node.statements):
statement_value = interpret(statement, new_sym_tab)
if index == len(node.statements) - 1:
if statement_value != ast.Literal(location=L, value=None):
return statement_value
else:
return Unit()
case ast.Var():
if sym_tab.lookup(node.name) is not None:
raise Exception()
sym_tab.assign(node.name, interpret(node.value))
return Unit()
return -1 # This line is only added for the linter; it should never be actually reached.
+113
View File
@@ -0,0 +1,113 @@
from dataclasses import dataclass
from typing import Callable
@dataclass
class IntrinsicArgs():
arg_refs: list[str]
result_register: str
emit: Callable[[str], None]
Intrinsic = Callable[[IntrinsicArgs], None]
all_intrinsics: dict[str, Intrinsic] = {}
def _intrinsic(name: str) -> Callable[[Intrinsic], Intrinsic]:
"""Function decorator that registers that function as an intrinsic."""
def wrapper(f: Intrinsic) -> Intrinsic:
assert name not in all_intrinsics
all_intrinsics[name] = f
return f
return wrapper
@_intrinsic("unary_-")
def unary_minus(a: IntrinsicArgs) -> None:
a.emit(f'movq {a.arg_refs[0]}, {a.result_register}')
a.emit(f'negq {a.result_register}')
@_intrinsic("unary_not")
def unary_not(a: IntrinsicArgs) -> None:
a.emit(f'movq {a.arg_refs[0]}, {a.result_register}')
a.emit(f'xorq $1, {a.result_register}')
@_intrinsic("+")
def plus(a: IntrinsicArgs) -> None:
if a.result_register != a.arg_refs[0]:
a.emit(f'movq {a.arg_refs[0]}, {a.result_register}')
a.emit(f'addq {a.arg_refs[1]}, {a.result_register}')
@_intrinsic("-")
def minus(a: IntrinsicArgs) -> None:
if a.result_register != a.arg_refs[0]:
a.emit(f'movq {a.arg_refs[0]}, {a.result_register}')
a.emit(f'subq {a.arg_refs[1]}, {a.result_register}')
@_intrinsic("*")
def multiply(a: IntrinsicArgs) -> None:
if a.result_register != a.arg_refs[0]:
a.emit(f'movq {a.arg_refs[0]}, {a.result_register}')
a.emit(f'imulq {a.arg_refs[1]}, {a.result_register}')
@_intrinsic("/")
def divide(a: IntrinsicArgs) -> None:
a.emit(f'movq {a.arg_refs[0]}, %rax')
a.emit('cqto') # TODO: explain
a.emit(f'idivq {a.arg_refs[1]}')
if a.result_register != '%rax':
a.emit(f'movq %rax, {a.result_register}')
@_intrinsic("%")
def remainder(a: IntrinsicArgs) -> None:
a.emit(f'movq {a.arg_refs[0]}, %rax')
a.emit('cqto')
a.emit(f'idivq {a.arg_refs[1]}')
if a.result_register != '%rdx':
a.emit(f'movq %rdx, {a.result_register}')
@_intrinsic("==")
def eq(a: IntrinsicArgs) -> None:
_int_comparison(a, 'sete')
@_intrinsic("!=")
def ne(a: IntrinsicArgs) -> None:
_int_comparison(a, 'setne')
@_intrinsic("<")
def lt(a: IntrinsicArgs) -> None:
_int_comparison(a, 'setl')
@_intrinsic("<=")
def le(a: IntrinsicArgs) -> None:
_int_comparison(a, 'setle')
@_intrinsic(">")
def gt(a: IntrinsicArgs) -> None:
_int_comparison(a, 'setg')
@_intrinsic(">=")
def ge(a: IntrinsicArgs) -> None:
_int_comparison(a, 'setge')
def _int_comparison(a: IntrinsicArgs, setcc_insn: str) -> None:
a.emit('xor %rax, %rax')
a.emit(f'movq {a.arg_refs[0]}, %rdx')
a.emit(f'cmpq {a.arg_refs[1]}, %rdx')
a.emit(f'{setcc_insn} %al')
if a.result_register != '%rax':
a.emit(f'movq %rax, {a.result_register}')
+86
View File
@@ -0,0 +1,86 @@
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Any
from compiler.tokenizer import Location
@dataclass(frozen=True)
class IRVar:
"""Represents the name of a memory location or built-in."""
name: str
def __str__(self) -> str:
return self.name
@dataclass(frozen=True)
class Instruction():
"""Base class for IR instructions."""
location: Location
def __str__(self) -> str:
"""Returns a string representation similar to
our IR code examples, e.g. 'LoadIntConst(3, x1)'"""
def format_value(v: Any) -> str:
if isinstance(v, list):
return f'[{", ".join(format_value(e) for e in v)}]'
else:
return str(v)
args = ', '.join(
format_value(getattr(self, field.name))
for field in fields(self)
if field.name != 'location'
)
return f'{type(self).__name__}({args})'
@dataclass(frozen=True)
class LoadBoolConst(Instruction):
"""Loads a boolean constant value to `dest`."""
value: bool
dest: IRVar
@dataclass(frozen=True)
class LoadIntConst(Instruction):
"""Loads an int constant value to `dest`."""
value: int
dest: IRVar
@dataclass(frozen=True)
class Copy(Instruction):
"""Copies a value from one variable to another."""
source: IRVar
dest: IRVar
@dataclass(frozen=True)
class Call(Instruction):
"""Calls a function or built-in."""
fun: IRVar
args: list[IRVar]
dest: IRVar
@dataclass(frozen=True)
class Jump(Instruction):
"""Unconditionally continues execution from the given label."""
label: Label
@dataclass(frozen=True)
class CondJump(Instruction):
"""Continues execution from `then_label` if `cond` is true, otherwise from `else_label`."""
cond: IRVar
then_label: Label
else_label: Label
@dataclass(frozen=True)
class Return(Instruction):
"""Returns the value of a function."""
result: IRVar
@dataclass(frozen=True)
class LoadCustomFuncArgs(Instruction):
args: list[IRVar]
@dataclass(frozen=True)
class Label(Instruction):
"""Marks the destination of a jump instruction."""
name: str
+424
View File
@@ -0,0 +1,424 @@
from __future__ import annotations
from compiler.ir import IRVar, Instruction, LoadIntConst, LoadBoolConst, Call, Label, CondJump, Jump, Copy, Return, LoadCustomFuncArgs
import compiler.ast as ast
from compiler.types import Bool_Instance, Int_Instance, Type, Unit_Instance, FunType, Any_Instance
from compiler.tokenizer import L, Location
import uuid
def print_instructions(ls: list[Instruction]) -> str:
output_str = ''
for instruction in ls:
output_str += f'{instruction.__str__()}\n'
return output_str
class SymTab[T]:
locals: dict[str, T]
parent: SymTab | None
def __init__(self, locals: dict = {}, parent: SymTab | None = None) -> None:
self.locals = locals
self.parent = parent
def get_top_level(self) -> SymTab:
if self.parent is None:
return self
return self.parent.get_top_level()
def assign(self, var_name: str, var_value: T) -> None:
self.locals[var_name] = var_value
def assign_recursive(self, var_name: str, var_value: T) -> None:
if self.locals.get(var_name) is not None:
self.locals[var_name] = var_value
else:
if self.parent is not None:
self.parent.assign_recursive(var_name, var_value)
else:
raise Exception()
def lookup(self, var_name: str) -> T:
lookup_result = self.locals.get(var_name)
if lookup_result is not None:
return lookup_result
else:
raise Exception()
def lookup_recursive(self, var_name: str) -> T:
lookup_result = self.locals.get(var_name)
if lookup_result is not None:
return lookup_result
else:
if self.parent is not None:
return self.parent.lookup_recursive(var_name)
else:
raise Exception(var_name)
binary_op_common_func_type = FunType(params=[Int_Instance, Int_Instance], result=Int_Instance)
binary_op_int_comparison_func_type = FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance)
binary_op_and_or_func_type = FunType(params=[Bool_Instance, Bool_Instance], result=Bool_Instance)
binary_op_eq_func_type = FunType(params=[Any_Instance, Any_Instance], result=Bool_Instance)
unary_op_int_func_type = FunType(params=[Int_Instance], result=Int_Instance)
unary_op_bool_func_type = FunType(params=[Bool_Instance], result=Bool_Instance)
print_int_func_type = FunType(params=[Int_Instance], result=Unit_Instance)
print_bool_func_type = FunType(params=[Bool_Instance], result=Unit_Instance)
read_int_func_type = FunType(params=[], result=Int_Instance)
root_types: dict[IRVar, Type] = {
IRVar('and'): binary_op_and_or_func_type,
IRVar('or') : binary_op_and_or_func_type,
IRVar('+'): binary_op_common_func_type,
IRVar('-'): binary_op_common_func_type,
IRVar('*'): binary_op_common_func_type,
IRVar('/'): binary_op_common_func_type,
IRVar('%'): binary_op_common_func_type,
IRVar('<'): binary_op_int_comparison_func_type,
IRVar('>'): binary_op_int_comparison_func_type,
IRVar('>='): binary_op_int_comparison_func_type,
IRVar('<='): binary_op_int_comparison_func_type,
IRVar('!='): binary_op_eq_func_type,
IRVar('=='): binary_op_eq_func_type,
IRVar('unary_-'): unary_op_int_func_type,
IRVar('unary_not'): unary_op_bool_func_type,
IRVar('print_int'): print_int_func_type,
IRVar('print_bool'): print_bool_func_type,
IRVar('read_int'): read_int_func_type,
}
def generate_ir(
# 'root_types' parameter should map all global names
# like 'print_int' and '+' to their types.
root_symtab: SymTab[IRVar],
root_types: dict[IRVar, Type],
root_expr: ast.Expression,
is_func_def: bool = False,
custom_func_ir_vars: list[IRVar] = []
) -> list[Instruction]:
var_types: dict[IRVar, Type] = root_types.copy()
# 'var_unit' is used when an expression's type is 'Unit'.
var_unit = IRVar('unit')
var_types[var_unit] = Unit_Instance
def new_var(t: Type) -> IRVar:
# Create a new unique IR variable and
# add it to var_types
new_ir_var = IRVar(str(uuid.uuid4()).replace('-', ''))
var_types[new_ir_var] = t
return new_ir_var
def new_label(location: Location = L) -> Label:
return Label(location, str(uuid.uuid4()).replace('-', ''))
# We collect the IR instructions that we generate
# into this list.
ins: list[Instruction] = []
# This function visits an AST node,
# appends IR instructions to 'ins',
# and returns the IR variable where
# the emitted IR instructions put the result.
#
# It uses a symbol table to map local variables
# (which may be shadowed) to unique IR variables.
# The symbol table will be updated in the same way as
# in the interpreter and type checker.
def visit(st: SymTab[IRVar], lst: SymTab[Label], expr: ast.Expression) -> IRVar:
loc = expr.location
match expr:
case ast.Literal():
# Create an IR variable to hold the value,
# and emit the correct instruction to
# load the constant value.
match expr.value:
case bool():
var = new_var(Bool_Instance)
ins.append(LoadBoolConst(loc, expr.value, var))
case int():
var = new_var(Int_Instance)
ins.append(LoadIntConst(loc, expr.value, var))
case None:
var = var_unit
case _:
raise Exception(f'{loc}: unsupported literal: {type(expr.value)}')
# Return the variable that holds
# the loaded value.
return var
case ast.Identifier():
# Look up the IR variable that corresponds to
# the source code variable.
return st.lookup_recursive(expr.name)
case ast.BinaryOp():
if expr.op == '=':
if not isinstance(expr.left, ast.Identifier):
raise Exception()
ir_var_left = visit(st, lst, expr.left)
ir_var_right = visit(st, lst, expr.right)
ins.append(Copy(loc, ir_var_right, ir_var_left))
return ir_var_right
elif expr.op == 'and':
v_result = new_var(Bool_Instance)
v_left = visit(st, lst, expr.left)
label_1 = new_label()
label_2 = new_label()
label_3 = new_label()
ins.append(CondJump(loc, v_left, label_1, label_2))
ins.append(label_1)
v_right = visit(st, lst, expr.right)
ins.append(Copy(loc, v_right, v_result))
ins.append(Jump(loc, label_3))
ins.append(label_2)
ins.append(Copy(loc, v_left, v_result))
ins.append(label_3)
return v_result
elif expr.op == 'or':
v_result = new_var(Bool_Instance)
v_left = visit(st, lst, expr.left)
label_1 = new_label()
label_2 = new_label()
label_3 = new_label()
ins.append(CondJump(loc, v_left, label_1, label_2))
ins.append(label_1)
ins.append(Copy(loc, v_left, v_result))
ins.append(Jump(loc, label_3))
ins.append(label_2)
v_right = visit(st, lst, expr.right)
ins.append(Copy(loc, v_right, v_result))
ins.append(label_3)
return v_result
else:
# Ask the symbol table to return the variable that refers
# to the operator to call.
var_op = st.get_top_level().lookup(expr.op)
# Recursively emit instructions to calculate the operands.
var_left = visit(st, lst, expr.left)
var_right = visit(st, lst, expr.right)
# Generate variable to hold the result.
var_result = new_var(expr.type)
# Emit a Call instruction that writes to that variable.
ins.append(Call(loc, var_op, [var_left, var_right], var_result))
return var_result
case ast.UnaryOp():
var_op = st.get_top_level().lookup(f'unary_{expr.op}')
var_right = visit(st, lst, expr.right)
var_result = new_var(expr.type)
ins.append(Call(loc, var_op, [var_right], var_result))
return var_result
case ast.If():
if expr.else_exp is None:
# Create (but don't emit) some jump targets.
l_then = new_label(expr.then_exp.location)
l_end = new_label()
# Recursively emit instructions for
# evaluating the condition.
var_cond = visit(st, lst, expr.cond_exp)
# Emit a conditional jump instruction
# to jump to 'l_then' or 'l_end',
# depending on the content of 'var_cond'.
ins.append(CondJump(loc, var_cond, l_then, l_end))
# Emit the label that marks the beginning of
# the "then" branch.
ins.append(l_then)
# Recursively emit instructions for the "then" branch.
visit(st, lst, expr.then_exp)
# Emit the label that we jump to
# when we don't want to go to the "then" branch.
ins.append(l_end)
# An if-then expression doesn't return anything, so we
# return a special variable "unit".
return var_unit
else:
then_label = new_label(expr.then_exp.location)
else_label = new_label(expr.else_exp.location)
end_label = new_label()
result_ir_var = new_var(expr.type)
ir_var_cond = visit(st, lst, expr.cond_exp)
ins.append(CondJump(loc, ir_var_cond, then_label, else_label))
ins.append(then_label)
then_ir_var = visit(st, lst, expr.then_exp)
ins.append(Copy(loc, then_ir_var, result_ir_var))
ins.append(Jump(loc, end_label))
ins.append(else_label)
else_ir_var = visit(st, lst, expr.else_exp)
ins.append(Copy(loc, else_ir_var, result_ir_var))
ins.append(end_label)
return result_ir_var
case ast.While():
l0 = new_label(expr.while_exp.location)
l1 = new_label(expr.do_exp.location)
l2 = new_label()
new_lst: SymTab[Label] = SymTab(locals={}, parent=lst)
new_lst.assign('start_label', l0)
new_lst.assign('end_label', l2)
ins.append(l0)
x1 = visit(st, new_lst, expr.while_exp)
ins.append(CondJump(loc, x1, l1, l2))
ins.append(l1)
visit(st, new_lst, expr.do_exp)
ins.append(Jump(loc, l0))
ins.append(l2)
return var_unit
case ast.FunctionCall():
arg_ir_vars = []
for arg in expr.args:
arg_ir_var = visit(st, lst, arg)
arg_ir_vars.append(arg_ir_var)
ft = st.lookup_recursive(expr.name)
riv = new_var(expr.type)
ins.append(Call(loc, ft, [*arg_ir_vars], riv))
return riv
case ast.Block():
last_expr = expr.statements[len(expr.statements) - 1]
new_st: SymTab = SymTab(locals={}, parent=st)
if isinstance(last_expr, ast.Literal) and last_expr.value is None:
for index, statement in enumerate(expr.statements):
if index == len(expr.statements) - 1:
return var_unit
else:
visit(new_st, lst, statement)
else:
for index, statement in enumerate(expr.statements):
if index == len(expr.statements) - 1:
last_statement_ir_var = visit(new_st, lst, statement)
return last_statement_ir_var
else:
visit(new_st, lst, statement)
case ast.Var():
if st.locals.get(expr.name) is not None:
raise Exception()
ir_var_new = new_var(expr.type)
ir_var_for_new_var_value = visit(st, lst, expr.value)
ins.append(Copy(loc, ir_var_for_new_var_value, ir_var_new))
st.assign(expr.name, ir_var_new)
return var_unit
case ast.Break():
end_label = lst.lookup('end_label')
ins.append(Jump(loc, end_label))
return var_unit
case ast.Continue():
start_label = lst.lookup('start_label')
ins.append(Jump(loc, start_label))
return var_unit
case ast.Return():
rs_ir_var = visit(st, lst, expr.result)
ins.append(Return(loc, rs_ir_var))
return var_unit
return var_unit # This line is only added for the linter; it should never be actually reached.
# # Convert 'root_types' into a SymTab
# # that maps all available global names to
# # IR variables of the same name.
# # In the Assembly generator stage, we will give
# # definitions for these globals. For now,
# # they just need to exist.
# root_symtab = SymTab[IRVar]()
# for v in root_types.keys():
# root_symtab.assign(v.name, v)
loop_symtab = SymTab[Label]()
if is_func_def:
ins.append(LoadCustomFuncArgs(L, custom_func_ir_vars))
# Start visiting the AST from the root.
var_final_result = visit(root_symtab, loop_symtab, root_expr)
if is_func_def:
ins.append(Return(root_expr.location, var_final_result))
else:
if var_types[var_final_result] == Int_Instance:
print_int_ir_var = root_symtab.lookup('print_int')
ins.append(Call(root_expr.location, print_int_ir_var, [var_final_result], new_var(Int_Instance)))
elif var_types[var_final_result] == Bool_Instance:
print_bool_ir_var = root_symtab.lookup('print_bool')
ins.append(Call(root_expr.location, print_bool_ir_var, [var_final_result], new_var(Bool_Instance)))
return ins
def generate_ir_from_module(module: ast.Module) -> dict[str, list[Instruction]]:
result: dict[str, list[Instruction]] = {}
extended_types_for_main = root_types.copy()
for function_def in module.function_defs:
fdiv = IRVar(function_def.name)
extended_types_for_main.update({ fdiv: function_def.return_type })
for function_def in module.function_defs:
extended_types = extended_types_for_main.copy()
custom_func_ir_vars: list[IRVar] = []
for arg in function_def.args:
irv = IRVar(arg.name)
custom_func_ir_vars.append(irv)
extended_types.update({ irv: arg.type })
fdiv = IRVar(function_def.name)
extended_types.update({ fdiv: function_def.return_type})
root_symtab = SymTab[IRVar]()
for v in extended_types.keys():
root_symtab.assign(v.name, v)
result[function_def.name] = generate_ir(root_symtab, extended_types, function_def.block, True, custom_func_ir_vars)
root_symtab = SymTab[IRVar]()
for v in extended_types_for_main.keys():
root_symtab.assign(v.name, v)
result['main'] = generate_ir(root_symtab, extended_types_for_main, module.block)
return result
+337
View File
@@ -0,0 +1,337 @@
from compiler.tokenizer import Token
import compiler.ast as ast
from compiler.types import FunType, ParamsType
from compiler.type_checker import Unit_Instance, Int_Instance, Bool_Instance, Type
def parse(tokens: list[Token]) -> ast.Module:
if len(tokens) == 0:
raise Exception('len(tokens) == 0')
# This keeps track of which token we're looking at.
pos = 0
end_token = Token(
location=tokens[-1].location,
type='end',
text='',
)
# 'peek()' returns the token at 'pos',
# or a special 'end' token if we're past the end
# of the token list.
# This way we don't have to worry about going past
# the end elsewhere.
def peek() -> Token:
nonlocal pos, end_token
if pos < len(tokens):
return tokens[pos]
else:
return end_token
last_consumed_token: Token | None = None
# 'consume(expected)' returns the token at 'pos'
# and moves 'pos' forward.
#
# If the optional parameter 'expected' is given,
# it checks that the token being consumed has that text.
# If 'expected' is a list, then the token must have
# one of the texts in the list.
def consume(expected: str | list[str] | None = None) -> Token:
nonlocal pos, last_consumed_token
token = peek()
if isinstance(expected, str) and token.text != expected:
raise Exception(f'{token.location}: expected "{expected}", got "{token.text}"')
if isinstance(expected, list) and token.text not in expected:
comma_separated = ', '.join([f'"{e}"' for e in expected])
raise Exception(f'{token.location}: expected one of: {comma_separated}')
pos += 1
last_consumed_token = token
return token
binary_operators = [
(['='], 'right'),
(['or'], 'left'),
(['and'], 'left'),
(['==', '!='], 'left'),
(['<', '<=', '>', '>='], 'left'),
(['+', '-'], 'left'),
(['*', '/', '%'], 'left'),
]
def parse_expression(binary_precedence_level: int = 0, allow_var: bool = False) -> ast.Expression:
nonlocal pos
if binary_precedence_level == len(binary_operators):
if peek().text == 'not':
token = consume('not')
return ast.UnaryOp(location=token.location, op='not', right=parse_expression(binary_precedence_level))
elif peek().text == '-':
token = consume('-')
return ast.UnaryOp(location=token.location, op='-', right=parse_expression(binary_precedence_level))
else:
return parse_factor(allow_var=allow_var)
else:
left = parse_expression(binary_precedence_level + 1, allow_var)
precedence_level_array, associativity = binary_operators[binary_precedence_level]
while peek().text in precedence_level_array:
operator_token = consume()
operator = operator_token.text
next_level = 1 if associativity == 'left' else 0
right = parse_expression(binary_precedence_level + next_level, False)
left = ast.BinaryOp(
location=left.location,
left=left,
op=operator,
right=right
)
return left
def parse_factor(allow_var: bool = False) -> ast.Expression:
if peek().text == 'if':
return parse_if()
elif peek().text == 'while':
return parse_while()
elif peek().text == 'var':
if allow_var:
return parse_var()
else:
raise Exception('"var" is only allowed directly inside blocks {} and in top-level expressions')
elif peek().text == '(':
return parse_parenthesized()
elif peek().text == '{':
return parse_block()
elif peek().type == 'int_literal':
return parse_int_literal()
elif peek().type == 'bool_literal':
return parse_boolean_literal()
elif peek().type == 'identifier':
return parse_identifier()
elif peek().text == 'break':
token = consume('break')
return ast.Break(location=token.location, type=Unit_Instance)
elif peek().text == 'continue':
token = consume('continue')
return ast.Continue(location=token.location, type=Unit_Instance)
elif peek().text == 'return':
token = consume('return')
result = parse_expression()
return ast.Return(location=token.location, result=result)
else:
raise Exception(f'{peek().location}: expected "(", an integer literal or an identifier')
def parse_if() -> ast.If:
token = consume('if')
if_exp = parse_expression()
then_exp = None
else_exp = None
if peek().text == 'then':
consume('then')
then_exp = parse_expression()
else:
raise Exception(f'{peek().location}: "if" not followed by "then"')
if peek().text == 'else':
consume('else')
else_exp = parse_expression()
return ast.If(location=token.location, cond_exp=if_exp, then_exp=then_exp, else_exp=else_exp)
def parse_while() -> ast.While:
token = consume('while')
while_exp = parse_expression()
do_exp = None
if peek().text == 'do':
consume('do')
do_exp = parse_expression()
else:
raise Exception(f'{peek().location}: "while" not followed by "do"')
return ast.While(location=token.location, while_exp=while_exp, do_exp=do_exp)
def parse_parenthesized() -> ast.Expression:
consume('(')
# Recursively call the top level parsing function
# to parse whatever is inside the parentheses.
expr = parse_expression()
if peek().text != ')':
raise Exception(f'{peek().location}: "(" not followed by ")"')
consume(')')
return expr
def parse_var() -> ast.Var:
token = consume('var')
var_identifier = parse_identifier()
var_type = None
if peek().text == ':':
consume(':')
var_type = parse_type()
consume('=')
var_value = parse_expression()
return ast.Var(location=token.location, name=var_identifier.name, value=var_value, type_f=var_type)
def parse_type() -> Type:
next_token = peek()
if next_token.text == 'Unit':
consume('Unit')
return Unit_Instance
elif next_token.text == 'Int':
consume('Int')
return Int_Instance
elif next_token.text == 'Bool':
consume('Bool')
return Bool_Instance
elif next_token.text == '(':
return parse_function_type()
else:
raise Exception()
def parse_function_type() -> FunType:
consume('(')
params: list[Type | list['ParamsType']] = []
if peek().text != ')':
params.append(parse_type())
while peek().text == ',':
consume(',')
params.append(parse_type())
consume(')')
consume('=>')
result = parse_type()
return FunType(params=params, result=result)
def parse_function_def() -> ast.FunctionDef:
consume('fun')
function_name = consume().text
consume('(')
function_args: list[ast.FunctionArg] = []
while peek().text != ')':
arg_name = consume()
consume(':')
arg_type = parse_type()
function_args.append(ast.FunctionArg(arg_name.text, arg_type))
if peek().text == ',':
consume(',')
consume(')')
consume(':')
return_type = parse_type()
block = parse_block()
return ast.FunctionDef(function_name, function_args, return_type, block)
def parse_block() -> ast.Block:
token = consume('{')
# Empty block
if peek().text == '}':
consume('}')
return ast.Block(location=token.location, statements=[])
statements = [parse_expression(allow_var=True)]
last_token = tokens[pos - 1]
while True:
current_peek = peek().text
if current_peek == ';':
consume(';')
if peek().text == '}':
temp_token = consume('}')
statements.append(ast.Literal(location=temp_token.location, value=None))
return ast.Block(location=token.location, statements=statements)
else:
next_exp = parse_expression(allow_var=True)
statements.append(next_exp)
last_token = tokens[pos - 1]
elif current_peek == '{':
if last_token.text == ';' or last_token.text == '}':
block_exp = parse_block()
statements.append(block_exp)
last_token = tokens[pos - 1]
else:
raise Exception()
elif current_peek == '}':
break
else:
if last_token.text == '}':
next_exp = parse_expression(allow_var=True)
statements.append(next_exp)
last_token = tokens[pos - 1]
else:
raise Exception(f'{peek().location}: unexpected token "{current_peek}" after statement ending with "{last_token.text}"')
consume('}')
return ast.Block(location=token.location, statements=statements)
# This is the parsing function for integer literals.
# It checks that we're looking at an integer literal token,
# moves past it, and returns a 'Literal' AST node
# containing the integer from the token.
def parse_int_literal() -> ast.Literal:
if peek().type != 'int_literal':
raise Exception(f'{peek().location}: expected an integer literal')
token = consume()
return ast.Literal(location=token.location, value=int(token.text))
def parse_boolean_literal() -> ast.Literal:
if peek().type != 'bool_literal':
raise Exception(f'{peek().location}: expected an boolean literal')
token = consume()
return ast.Literal(location=token.location, value=True if token.text == 'true' else False)
def parse_identifier() -> ast.Identifier | ast.FunctionCall:
if peek().type != 'identifier':
raise Exception(f'{peek().location}: expected an identifier')
token = consume()
# Function call detected
if peek().text == '(':
consume('(')
args: list[ast.Expression] = []
if peek().text != ')': # Check if there are arguments
args.append(parse_expression())
while peek().text == ',':
consume(',')
args.append(parse_expression())
consume(')')
return ast.FunctionCall(location=token.location, name=str(token.text), args=args)
return ast.Identifier(location=token.location, name=str(token.text))
statements: list[ast.Expression] = []
function_defs: list[ast.FunctionDef] = []
location = tokens[0].location
while peek().text == 'fun':
function_defs.append(parse_function_def())
while peek().type != 'end':
exp = parse_expression(allow_var=True)
statements.append(exp)
if peek().text == ';':
consume(';')
if peek().type == 'end':
statements.append(ast.Literal(location=peek().location, value=None))
else:
if peek().type != 'end':
if last_consumed_token is not None and last_consumed_token.text == '}':
pass
else:
raise Exception(f'{peek().location}: expected ";" or "{{", got "{peek().text}"')
if pos < len(tokens):
raise Exception(f'{peek().location}: unexpected tokens remaining')
return ast.Module(function_defs, ast.Block(location=location, statements=statements))
+88
View File
@@ -0,0 +1,88 @@
import re
from dataclasses import dataclass
from typing import Any
@dataclass
class SourceLocation:
file: str
line: int
column: int
def __eq__(self, value: Any) -> bool:
if value == L:
return True
elif isinstance(value, SourceLocation) and self.file == value.file and self.line == value.line and self.column == value.column:
return True
else:
return False
@dataclass
class SpecialSourceLocation:
def __eq__(self, _: Any) -> bool:
return True
L = SpecialSourceLocation()
type Location = SourceLocation | SpecialSourceLocation
@dataclass
class Token:
location: Location
type: str
text: str
def tokenize(source_code: str) -> list[Token]:
whitespace_re = r'\s'
comment_re = r'(?:\/\/|#)[^\n]*(?=\n|$)'
command_re = r'continue|break|fun|return'
type_re = r'Int|Bool|Unit'
integer_re = r'[0-9]+'
boolean_re = r'true|false'
punctuation_re = r'=>|\(|\)|{|}|,|;|:'
operator_re = r'and|not|or|>=|<=|!=|==|\+|-|\*|/|%|=|<|>'
identifier_re = r'[a-zA-Z_][a-zA-Z0-9_]*'
token_re_list = [
(command_re, 'command'),
(type_re, 'type'),
(integer_re, 'int_literal'),
(boolean_re, 'bool_literal'),
(punctuation_re, 'punctuation'),
(operator_re, 'operator'),
(identifier_re, 'identifier')
]
tokens = []
source_code_lines = source_code.splitlines()
for row_index, line in enumerate(source_code_lines):
col_index = 0
while col_index < len(line):
whitespace_match = re.match(whitespace_re, line[col_index:])
comment_match = re.match(comment_re, line[col_index:])
if whitespace_match:
whitespace = whitespace_match.group(0)
col_index += len(whitespace)
continue
elif comment_match:
break
curr_col_index = col_index
for tup in token_re_list:
rex = tup[0]
rex_type = tup[1]
match = re.match(rex, line[col_index:])
if match:
token = match.group(0) # This is the token that matched.
tokens.append(Token(SourceLocation('', row_index, col_index), rex_type, token))
col_index += len(token)
break
if curr_col_index == col_index:
raise Exception()
return tokens
+301
View File
@@ -0,0 +1,301 @@
from __future__ import annotations
import compiler.ast as ast
from compiler.types import Int, Bool, Unit, Type, FunType, Int_Instance, Bool_Instance, Unit_Instance
from compiler.tokenizer import L
class TypeSymTab:
locals: dict[str, Type | dict[str, FunType]]
parent: TypeSymTab | None
def __init__(self, locals: dict[str, Type | dict[str, FunType] ] = {}, parent: TypeSymTab | None = None) -> None:
self.locals = locals
self.parent = parent
binary_op_common_func_type = FunType(params=[Int_Instance, Int_Instance], result=Int_Instance)
binary_op_int_comparison_func_type = FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance)
binary_op_and_or_func_type = FunType(params=[Bool_Instance, Bool_Instance], result=Bool_Instance)
binary_op_eq_func_type = FunType(params=[[Int_Instance, Bool_Instance], [Int_Instance, Bool_Instance]], result=Bool_Instance)
unary_op_int_func_type = FunType(params=[Int_Instance], result=Int_Instance)
unary_op_bool_func_type = FunType(params=[Bool_Instance], result=Bool_Instance)
print_int_func_type = FunType(params=[Int_Instance], result=Unit_Instance)
print_bool_func_type = FunType(params=[Bool_Instance], result=Unit_Instance)
read_int_func_type = FunType(params=[], result=Int_Instance)
if parent is None:
self.locals.update({
'binary_ops': {
'and': binary_op_and_or_func_type,
'or' : binary_op_and_or_func_type,
'+': binary_op_common_func_type,
'-': binary_op_common_func_type,
'*': binary_op_common_func_type,
'/': binary_op_common_func_type,
'%': binary_op_common_func_type,
'<': binary_op_int_comparison_func_type,
'>': binary_op_int_comparison_func_type,
'>=': binary_op_int_comparison_func_type,
'<=': binary_op_int_comparison_func_type,
'!=': binary_op_eq_func_type,
'==': binary_op_eq_func_type,
},
'unary_ops': {
'-': unary_op_int_func_type,
'not': unary_op_bool_func_type
},
'built_in_functions': {
'print_int': print_int_func_type,
'print_bool': print_bool_func_type,
'read_int': read_int_func_type
}
})
def get_top_level(self) -> TypeSymTab:
if self.parent is None:
return self
return self.parent.get_top_level()
def assign(self, var_name: str, var_type: Type) -> None:
self.locals[var_name] = var_type
def assign_recursive(self, var_name: str, var_type: Type) -> None:
if self.locals.get(var_name) is not None:
self.locals[var_name] = var_type
else:
if self.parent is not None:
self.parent.assign_recursive(var_name, var_type)
else:
raise Exception()
def lookup(self, var_name: str) -> Type:
res = self.locals.get(var_name)
if res is not None:
assert isinstance(res, Int | Bool | Unit | FunType)
return res
else:
raise Exception()
def lookup_recursive(self, var_name: str) -> Type:
res = self.locals.get(var_name)
if res is not None:
assert isinstance(res, Int | Bool | Unit | FunType)
return res
else:
if self.parent is not None:
return self.parent.lookup_recursive(var_name)
else:
raise Exception()
def typecheck_main(node: ast.Expression, sym_tab: TypeSymTab) -> Type:
match node:
case ast.Literal():
if type(node.value) is bool:
return Bool_Instance
elif type(node.value) is int:
return Int_Instance
elif node.value is None:
return Unit_Instance
else:
raise Exception()
case ast.Identifier():
return sym_tab.lookup_recursive(node.name)
case ast.BinaryOp():
t1 = typecheck(node.left, sym_tab)
t2 = typecheck(node.right, sym_tab)
binary_ops = sym_tab.get_top_level().locals['binary_ops']
assert isinstance(binary_ops, dict)
func_type = binary_ops.get(node.op)
if func_type is None:
if node.op != '=':
raise Exception()
# = operator
else:
if not isinstance(node.left, ast.Identifier):
raise Exception()
if t1 != t2:
raise Exception()
return t2
else:
assert isinstance(func_type, FunType)
assert isinstance(func_type.params, list)
if isinstance(func_type.params[0], list):
if not any(t1 == possible_type for possible_type in func_type.params[0]):
raise Exception()
else:
if t1 != func_type.params[0]:
raise Exception()
if isinstance(func_type.params[1], list):
if not any(t2 == possible_type for possible_type in func_type.params[1]):
raise Exception()
else:
if t2 != func_type.params[1]:
raise Exception()
# == and !=
if isinstance(func_type.params[0], list) and isinstance(func_type.params[1], list):
if t1 != t2:
raise Exception()
return func_type.result
case ast.UnaryOp():
t1 = typecheck(node.right, sym_tab)
unary_ops = sym_tab.get_top_level().locals['unary_ops']
assert isinstance(unary_ops, dict)
func_type = unary_ops.get(node.op)
assert isinstance(func_type, FunType)
assert isinstance(func_type.params, list)
if t1 != func_type.params[0]:
raise Exception()
return func_type.result
case ast.If():
e1 = node.cond_exp
e2 = node.then_exp
e3 = node.else_exp
# if-then
if e3 is None:
return Unit_Instance
# if-then-else
else:
t1 = typecheck(e1, sym_tab)
if t1 != Bool_Instance:
raise Exception()
t2 = typecheck(e2, sym_tab)
t3 = typecheck(e3, sym_tab)
if t2 != t3:
raise Exception()
return t2
case ast.While():
if typecheck(node.while_exp, sym_tab) != Bool_Instance:
raise Exception()
return Unit_Instance
case ast.FunctionCall():
built_in_functions = sym_tab.get_top_level().locals['built_in_functions']
assert isinstance(built_in_functions, dict)
func_type = built_in_functions.get(node.name)
# Built-in functions
if func_type is not None:
assert isinstance(func_type, FunType)
assert isinstance(func_type.params, list)
if len(node.args) != len(func_type.params):
raise Exception()
if len(node.args) == 1:
if typecheck(node.args[0], sym_tab) != func_type.params[0]:
raise Exception()
return Unit_Instance
else: # len(node.args) == 0, i.e. node.name == 'read_int'
return Int_Instance
else:
custom_func_type = sym_tab.lookup_recursive(node.name)
assert isinstance(custom_func_type, FunType)
assert isinstance(custom_func_type.params, list)
if len(node.args) != len(custom_func_type.params):
raise Exception()
for i in range(0, len(node.args)):
if typecheck(node.args[i], sym_tab) != custom_func_type.params[i]:
raise Exception()
return custom_func_type.result
case ast.Block():
if len(node.statements) == 0:
return Unit_Instance
last_node = node.statements[len(node.statements) - 1]
if last_node == ast.Literal(location=L, value=None):
new_sym_tab = TypeSymTab(locals={}, parent=sym_tab)
for idx, statement in enumerate(node.statements):
if idx == len(node.statements) - 1:
return Unit_Instance
else:
typecheck(node=statement, sym_tab=new_sym_tab)
raise Exception() # The loop should have returned a result and this line should not have been reached.
else:
new_sym_tab = TypeSymTab(locals={}, parent=sym_tab)
for idx, statement in enumerate(node.statements):
if idx == len(node.statements) - 1:
return typecheck(node=statement, sym_tab=new_sym_tab)
else:
typecheck(node=statement, sym_tab=new_sym_tab)
raise Exception() # The loop should have returned a result and this line should not have been reached.
case ast.Var():
if node.type_f is not None:
value_type = typecheck(node=node.value, sym_tab=sym_tab)
if value_type is Unit_Instance or node.type_f != value_type:
raise Exception()
sym_tab.assign(var_name=node.name, var_type=node.type_f)
return node.type_f
else:
value_type = typecheck(node=node.value, sym_tab=sym_tab)
sym_tab.assign(var_name=node.name, var_type=value_type)
return value_type
case ast.Break():
return Unit_Instance
case ast.Continue():
return Unit_Instance
case ast.Var():
return Unit_Instance
case ast.Return():
assert sym_tab.parent is not None
if typecheck(node.result, sym_tab) != sym_tab.parent.lookup_recursive('return'):
raise Exception()
return Unit_Instance
case _:
raise Exception()
def typecheck(node: ast.Expression, sym_tab: TypeSymTab | None = None) -> Type:
if sym_tab is None:
sym_tab = TypeSymTab()
resulting_type = typecheck_main(node, sym_tab)
node.type = resulting_type
return resulting_type
def typecheck_module(module: ast.Module) -> None:
st_block = TypeSymTab(locals={}, parent=None)
for function_def in module.function_defs:
st_block.assign(function_def.name, FunType(params=[arg.type for arg in function_def.args], result=function_def.return_type))
for function_def in module.function_defs:
st = TypeSymTab(locals={}, parent=st_block)
for arg in function_def.args:
st.assign(arg.name, arg.type)
st.assign('return', function_def.return_type)
typecheck(function_def.block, st)
typecheck(module.block, st_block)
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
from dataclasses import dataclass
class Int:
pass
class Bool:
pass
class Unit:
pass
Int_Instance = Int()
Bool_Instance = Bool()
Unit_Instance = Unit()
class Any(Int, Bool, Unit):
pass
Any_Instance = Any()
type Type = Int | Bool | Unit | FunType
ParamsType = Type | list[Type | list['ParamsType']]
@dataclass
class FunType:
params: ParamsType
result: Type
Executable
+31
View File
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# Runs the correct test-gadget client program in .test-gadget/test-gadget-client-$PLATFORM
import os
import platform
import sys
from pathlib import Path
def get_platform_binary() -> str:
system = platform.system().lower()
if system == "darwin":
return "test-gadget-client-macos"
elif system == "windows":
return "test-gadget-client-windows.exe"
elif system == "linux":
return "test-gadget-client-linux"
else:
print(f"Unsupported platform: {system}", file=sys.stderr)
sys.exit(1)
script_dir = Path(__file__).parent
dist_dir = script_dir / ".test-gadget"
binary = dist_dir / get_platform_binary()
if not binary.exists():
print(f"Program not found: {binary}", file=sys.stderr)
sys.exit(1)
os.execv(str(binary), [str(binary)] + sys.argv[1:])
View File
+135
View File
@@ -0,0 +1,135 @@
from compiler.assembly_generator import Locals, generate_assembly
from compiler.ir import IRVar, LoadIntConst, LoadBoolConst, Copy, CondJump, Label, Instruction, Call
from compiler.tokenizer import L
from typing import List
def test_assembly_generator_locals_initialization() -> None:
variables = [IRVar('x'), IRVar('y'), IRVar('z')]
locals = Locals(variables)
assert locals.get_ref(variables[0]) == '-8(%rbp)'
assert locals.get_ref(variables[1]) == '-16(%rbp)'
assert locals.get_ref(variables[2]) == '-24(%rbp)'
assert locals.stack_used() == 24 # 3 variables * 8 bytes
# def test_assembly_generator_load_int_const() -> None:
# ir_var = IRVar('x')
# instructions: List[Instruction] = [
# LoadIntConst(L, 42, ir_var)
# ]
# asm = generate_assembly(instructions)
# assert 'movq $42, -8(%rbp)' in asm
# def test_assembly_generator_load_bool_const() -> None:
# instructions: List[Instruction] = [
# LoadBoolConst(L, True, IRVar('a')),
# LoadBoolConst(L, False, IRVar('b'))
# ]
# asm = generate_assembly(instructions)
# assert 'movq $1, -8(%rbp)' in asm # True
# assert 'movq $0, -16(%rbp)' in asm # False
# def test_assembly_generator_copy() -> None:
# src = IRVar('src')
# dest = IRVar('dest')
# instructions: List[Instruction] = [
# Copy(L, src, dest)
# ]
# asm = generate_assembly(instructions)
# assert 'movq -8(%rbp), %rax' in asm
# assert 'movq %rax, -16(%rbp)' in asm
# def test_assembly_generator_cond_jump() -> None:
# cond_var = IRVar('cond')
# then_label = Label(L, 'Lthen')
# else_label = Label(L, 'Lelse')
# instructions: List[Instruction] = [
# CondJump(L, cond_var, then_label, else_label)
# ]
# asm = generate_assembly(instructions)
# assert 'cmpq $0, -8(%rbp)' in asm
# assert 'jne .LLthen' in asm
# assert 'jmp .LLelse' in asm
# def test_assembly_generator_function_prologue_epilogue() -> None:
# instructions: List[Instruction] = [
# LoadIntConst(L, 0, IRVar('dummy'))
# ]
# asm = generate_assembly(instructions)
# assert 'pushq %rbp' in asm
# assert 'movq %rsp, %rbp' in asm
# assert 'subq $8, %rsp' in asm
# assert 'movq %rbp, %rsp' in asm
# assert 'popq %rbp' in asm
# assert 'ret' in asm
# def test_assembly_generator_intrinsic_plus() -> None:
# # IR: Call('+', [x, y], result)
# x = IRVar('x')
# y = IRVar('y')
# result = IRVar('result')
# instructions = [
# LoadIntConst(L, 3, x),
# LoadIntConst(L, 5, y),
# Call(L, IRVar('+'), [x, y], result)
# ]
# asm = generate_assembly(instructions)
# assert 'addq' in asm
# assert 'movq' in asm
# assert 'callq' not in asm # Intrinsic should not use call
# def test_assembly_generator_intrinsic_divide() -> None:
# # IR: Call('/', [a, b], result)
# a = IRVar('a')
# b = IRVar('b')
# instructions = [
# LoadIntConst(L, 10, a),
# LoadIntConst(L, 2, b),
# Call(L, IRVar('/'), [a, b], IRVar('result'))
# ]
# asm = generate_assembly(instructions)
# assert 'idivq' in asm
# assert 'cqto' in asm
# def test_assembly_generator_function_call_one_arg() -> None:
# # IR: Call(print_int, [x], _)
# x = IRVar('x')
# instructions = [
# LoadIntConst(L, 42, x),
# Call(L, IRVar('print_int'), [x], IRVar('unused'))
# ]
# asm = generate_assembly(instructions)
# assert 'movq -8(%rbp), %rdi' in asm # Assuming x is at -8(%rbp)
# assert 'callq print_int' in asm
# def test_assembly_generator_function_call_six_args() -> None:
# # IR: Call(func, [a,b,c,d,e,f], _)
# args = [IRVar(f'arg{i}') for i in range(6)]
# instructions = [
# *[LoadIntConst(L, i, arg) for i, arg in enumerate(args)],
# Call(L, IRVar('func'), args, IRVar('result'))
# ]
# asm = generate_assembly(instructions)
# expected_regs = ['%rdi', '%rsi', '%rdx', '%rcx', '%r8', '%r9']
# for i, reg in enumerate(expected_regs):
# assert f'movq -{8*(i+1)}(%rbp), {reg}' in asm
# assert 'callq func' in asm
# def test_assembly_generator_comparison_intrinsic() -> None:
# # IR: Call('==', [x, y], result)
# x = IRVar('x')
# y = IRVar('y')
# instructions = [
# LoadIntConst(L, 5, x),
# LoadIntConst(L, 5, y),
# Call(L, IRVar('=='), [x, y], IRVar('result'))
# ]
# asm = generate_assembly(instructions)
# assert 'cmpq' in asm
# assert 'sete' in asm
+419
View File
@@ -0,0 +1,419 @@
import pytest
from typing import Any
from compiler.parser import parse
from compiler.tokenizer import tokenize
from compiler.interpreter import interpret, SymTab, Unit
def test_interpreter_addition() -> None:
exp = '2 + 3'
tokens = tokenize(exp)
ast = parse(tokens)
assert interpret(ast.block) == 5
def test_interpreter_boolean_coercion() -> None:
exp = '1 + (2 < 3)'
tokens = tokenize(exp)
ast = parse(tokens)
assert interpret(ast.block) == 2
def test_interpreter_nested_blocks() -> None:
code = '''
{
var x = 5;
{
var y = 10;
x + y
}
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 15
def test_interpreter_variable_shadowing() -> None:
code = '''
var x = 1;
{
var x = 2;
x
}
x
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 1
def test_interpreter_assignment_in_outer_scope() -> None:
code = '''
var x = 5;
{
x = x + 1;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 6
def test_interpreter_assignment_operator() -> None:
code = 'var x = 0; x = 5; x'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 5
def test_interpreter_operator_precedence() -> None:
code = '10 - 4 - 3' # (10-4)-3 = 3
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 3
def test_interpreter_logical_operators() -> None:
code = 'true and false or true' # (true and false) or true = true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_comparison_equality() -> None:
code = '5 > 3 == true' # (5>3) == true → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_unary_operators() -> None:
code = 'not (5 < 3)' # not(false) → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_block_scoping() -> None:
code = '''
{
var a = 10;
{
var b = 20;
a + b
}
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 30
def test_interpreter_variable_redeclaration() -> None:
code = '''
{
var x = 1;
var x = 2;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_while_loop() -> None:
code = '''
var counter = 3;
while counter > 0 do {
counter = counter - 1
};
counter
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 0
def test_interpreter_empty_block() -> None:
code = '{}'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_block_last_semicolon() -> None:
code = '''
{
var x = 5;
x;
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_if_without_else_returns_unit() -> None:
code = 'if false then 5'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_if_else_true_branch() -> None:
code = 'if true then 10 else 20'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 10
def test_interpreter_if_else_false_branch() -> None:
code = 'if false then 10 else 20'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 20
def test_interpreter_function_call_raises_error(capsys: pytest.CaptureFixture) -> None:
code = 'print_int(5)'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
captured = capsys.readouterr()
assert captured.out == '5\n'
def test_interpreter_division_modulo_operators() -> None:
code = '8 / 3 + 8 % 3' # 2 + 2 = 4
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 4
def test_interpreter_assignment_to_non_identifier_raises() -> None:
code = '5 = 10'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_while_loop_zero_iterations() -> None:
code = '''
var x = 3;
while x > 5 do {
x = x - 1
};
x
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 3
def test_interpreter_nested_if_expression() -> None:
code = '(if true then 5 else 3) + 2'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 7
def test_interpreter_unary_not_on_boolean() -> None:
code = 'not true'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is False
def test_interpreter_unary_negate_non_int_raises() -> None:
code = '-true'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_and_short_circuit() -> None:
code = '''
var evaluated = false;
false and { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is False
def test_interpreter_and_evaluates_both() -> None:
code = '''
var evaluated = false;
true and { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_or_short_circuit() -> None:
code = '''
var evaluated = false;
true or { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is False
def test_interpreter_or_evaluates_both() -> None:
code = '''
var evaluated = false;
false or { evaluated = true; true };
evaluated
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_unary_negate_int() -> None:
code = '-5'
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == -5
def test_interpreter_unary_negate_bool_raises() -> None:
code = '-true'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_unary_not_on_int_raises() -> None:
code = 'not 5'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_binary_operator_precedence() -> None:
code = '1 + 2 * 3 == 7' # 1 + (2 * 3) == 7 → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_binary_operator_associativity() -> None:
code = '10 - 4 - 3' # (10 - 4) - 3 = 3
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 3
def test_interpreter_binary_operator_equality() -> None:
code = '5 == 5 and 3 != 4' # true and true → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_binary_operator_comparison() -> None:
code = '5 < 10 and 10 >= 10' # true and true → true
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) is True
def test_interpreter_binary_operator_division_by_zero_raises() -> None:
code = '5 / 0'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_binary_operator_modulo_by_zero_raises() -> None:
code = '5 % 0'
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_while_loop_multiple_iterations() -> None:
code = '''
var x = 3;
while x > 0 do {
x = x - 1
};
x
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 0
def test_interpreter_block_with_multiple_statements() -> None:
code = '''
{
var x = 5;
var y = 10;
x + y
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 15
def test_interpreter_block_with_trailing_semicolon() -> None:
code = '''
{
var x = 5;
x;
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_block_with_nested_blocks() -> None:
code = '''
{
var x = 5;
{
var y = 10;
x + y
}
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 15
def test_interpreter_variable_declaration_in_block() -> None:
code = '''
{
var x = 5;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 5
def test_interpreter_variable_redeclaration_in_block_raises() -> None:
code = '''
{
var x = 5;
var x = 10;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
with pytest.raises(Exception):
interpret(ast.block)
def test_interpreter_assignment_in_block() -> None:
code = '''
{
var x = 5;
x = x + 1;
x
}
'''
tokens = tokenize(code)
ast = parse(tokens)
assert interpret(ast.block) == 6
def test_interpreter_function_call_print_int() -> None:
code = 'print_int(5)'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_function_call_print_bool() -> None:
code = 'print_bool(true)'
tokens = tokenize(code)
ast = parse(tokens)
assert isinstance(interpret(ast.block), Unit)
def test_interpreter_function_call_read_int() -> None:
code = 'read_int()'
tokens = tokenize(code)
ast = parse(tokens)
# Mock input for read_int
import builtins
original_input: Any = builtins.input
func: Any = lambda: '42'
builtins.input = func
assert interpret(ast.block) == 42
builtins.input = original_input
+12
View File
@@ -0,0 +1,12 @@
from compiler.type_checker import typecheck_module
from compiler.tokenizer import L, tokenize
from compiler.parser import parse
from compiler.ir_generator import generate_ir_from_module, print_instructions, root_types
def test_ir_generator_basic () -> None:
expr_str = '1 + 2 * 3'
tokens = tokenize(expr_str)
ast = parse(tokens)
typecheck_module(ast)
main_instructions = generate_ir_from_module(ast)['main']
assert print_instructions(main_instructions) != ''
+2505
View File
File diff suppressed because it is too large Load Diff
+28
View File
@@ -0,0 +1,28 @@
from compiler.tokenizer import tokenize, Token, L
def test_tokenizer_basic() -> None:
assert tokenize('aaa 123 bbb') == [
Token(location=L, type='identifier', text='aaa'),
Token(location=L, type='int_literal', text='123'),
Token(location=L, type='identifier', text='bbb'),
]
def test_tokenizer_newline() -> None:
assert tokenize('if 3\nwhile') == [
Token(location=L, type='identifier', text='if'),
Token(location=L, type='int_literal', text='3'),
Token(location=L, type='identifier', text='while'),
]
def test_tokenizer_commments() -> None:
assert tokenize('aaa 123 bbb ; ) ( >= \n ) # aksdjalksjdkajskdjasd\n != // Another comment $') == [
Token(location=L, type='identifier', text='aaa'),
Token(location=L, type='int_literal', text='123'),
Token(location=L, type='identifier', text='bbb'),
Token(location=L, type='punctuation', text=';'),
Token(location=L, type='punctuation', text=')'),
Token(location=L, type='punctuation', text='('),
Token(location=L, type='operator', text='>='),
Token(location=L, type='punctuation', text=')'),
Token(location=L, type='operator', text='!='),
]
+921
View File
@@ -0,0 +1,921 @@
import pytest
from compiler.type_checker import typecheck, TypeSymTab
from compiler import ast
from compiler.types import Int_Instance, Bool_Instance, Unit_Instance, FunType
from compiler.tokenizer import L
def test_type_checker_assignment_with_unknown_variable() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, 5)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_literal_int() -> None:
node = ast.Literal(location=L, value=5)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_literal_bool() -> None:
node = ast.Literal(location=L, value=True)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_literal_unit() -> None:
node = ast.Literal(location=L, value=None)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_variable_lookup() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.Identifier(location=L, name='x')
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_variable_undefined() -> None:
node = ast.Identifier(location=L, name='x')
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_binary_op_add_valid() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 1),
op='+',
right=ast.Literal(L, 2)
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_binary_op_add_invalid() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, True),
op='+',
right=ast.Literal(L, 2)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_unary_op_negate_valid() -> None:
node = ast.UnaryOp(
location=L,
op='-',
right=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_unary_op_not_valid() -> None:
node = ast.UnaryOp(
location=L,
op='not',
right=ast.Literal(L, False)
)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_function_call_builtin() -> None:
node = ast.FunctionCall(
location=L,
name='print_int',
args=[ast.Literal(L, 5)]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_block_last_expr() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 1),
ast.Literal(L, 2)
]
)
result = typecheck(node)
assert result == Int_Instance
assert node.statements[-1].type == Int_Instance
def test_type_checker_block_unit() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 1),
ast.Literal(L, None)
]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.statements[-1].type == Unit_Instance
def test_type_checker_if_then_else_valid() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 1),
else_exp=ast.Literal(L, 2)
)
result = typecheck(node)
assert result == Int_Instance
assert node.then_exp.type == Int_Instance
def test_type_checker_if_then_else_mismatch() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 1),
else_exp=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_while_loop() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, True),
do_exp=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_assignment_valid() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, 5)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_mismatch() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_var_declaration_inferred() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=None
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_typed_valid() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=Int_Instance
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_typed_mismatch() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, True),
type_f=Int_Instance
)
with pytest.raises(Exception):
typecheck(node)
# def test_type_checker_function_type_params() -> None:
# node = ast.FunctionCall(
# location=L,
# name='+',
# args=[
# ast.Literal(L, 1),
# ast.Literal(L, 2)
# ]
# )
# result = typecheck(node)
# assert result == Int_Instance
# assert node.type == Int_Instance
def test_type_checker_symbol_table_hierarchy() -> None:
parent = TypeSymTab(locals={'x': Int_Instance})
child = TypeSymTab(parent=parent)
node = ast.Identifier(L, 'x')
result = typecheck(node, child)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_equality_operator_valid() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='==',
right=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_equality_operator_mismatch() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='==',
right=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_function_call_user_defined() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5)]
)
result = typecheck(node, sym_tab)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_nested_blocks() -> None:
node = ast.Block(
location=L,
statements=[
ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Block(
location=L,
statements=[
ast.Literal(L, True)
]
)
]
)
]
)
result = typecheck(node)
assert result == Bool_Instance
assert isinstance(node.statements[0], ast.Block)
assert node.statements[0].statements[-1].type == Bool_Instance
def test_type_checker_function_type_annotation() -> None:
node = ast.Var(
location=L,
name='f',
value=ast.Literal(L, 5),
type_f=FunType(params=[Int_Instance], result=Bool_Instance)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_function_call_with_wrong_arg_count() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5)]
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_function_call_with_wrong_arg_type() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, True)]
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_function_call_with_nested_args() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Bool_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[
ast.Literal(L, 5),
ast.BinaryOp(
location=L,
left=ast.Literal(L, True),
op='and',
right=ast.Literal(L, False)
)
]
)
result = typecheck(node, sym_tab)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_function_call_with_builtin_print_int() -> None:
node = ast.FunctionCall(
location=L,
name='print_int',
args=[ast.Literal(L, 5)]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_function_call_with_builtin_read_int() -> None:
node = ast.FunctionCall(
location=L,
name='read_int',
args=[]
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_function_call_with_builtin_print_bool() -> None:
node = ast.FunctionCall(
location=L,
name='print_bool',
args=[ast.Literal(L, True)]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_function_call_with_unknown_function() -> None:
node = ast.FunctionCall(
location=L,
name='unknown_func',
args=[ast.Literal(L, 5)]
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_if_then_without_else() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=None
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_if_then_with_else() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=ast.Literal(L, 10)
)
result = typecheck(node)
assert result == Int_Instance
assert node.type == Int_Instance
def test_type_checker_if_then_with_else_mismatch() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_while_loop_with_non_bool_condition() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, 5),
do_exp=ast.Literal(L, 10)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_while_loop_with_unit_body() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, True),
do_exp=ast.Literal(L, None)
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_block_with_multiple_statements() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Literal(L, True),
ast.Literal(L, None)
]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_block_with_nested_blocks() -> None:
node = ast.Block(
location=L,
statements=[
ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Block(
location=L,
statements=[
ast.Literal(L, True)
]
)
]
)
]
)
result = typecheck(node)
assert result == Bool_Instance
assert node.type == Bool_Instance
def test_type_checker_block_with_empty_statements() -> None:
node = ast.Block(
location=L,
statements=[]
)
result = typecheck(node)
assert result == Unit_Instance
assert node.type == Unit_Instance
def test_type_checker_var_declaration_with_nested_expression() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='+',
right=ast.Literal(L, 10)
),
type_f=None
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_with_typed_initializer() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=Int_Instance
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_with_typed_initializer_mismatch() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, True),
type_f=Int_Instance
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_var_declaration_with_untyped_initializer() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=None
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_declaration_with_function_type() -> None:
node = ast.Var(
location=L,
name='f',
value=ast.Literal(L, 5),
type_f=FunType(params=[Int_Instance], result=Bool_Instance)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_assignment_with_function_call() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='read_int',
args=[]
)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_with_function_call_mismatch() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='print_int',
args=[ast.Literal(L, 5)]
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_nested_expression() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='+',
right=ast.Literal(L, 10)
)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_with_nested_expression_mismatch() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='+',
right=ast.Literal(L, True)
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_wrong_type() -> None:
sym_tab = TypeSymTab(locals={'x': Int_Instance})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.Literal(L, True)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_function_type() -> None:
sym_tab = TypeSymTab(locals={'f': FunType(params=[Int_Instance], result=Bool_Instance)})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'f'),
op='=',
right=ast.Literal(L, 5)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_function_call_returning_function() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=FunType(params=[Int_Instance], result=Bool_Instance))
})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='f',
args=[]
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_assignment_with_function_call_returning_int() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=Int_Instance),
'x': Int_Instance
})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='f',
args=[]
)
)
result = typecheck(node, sym_tab)
assert result == Int_Instance
assert node.right.type == Int_Instance
def test_type_checker_assignment_with_function_call_returning_bool() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=Bool_Instance),
'x': Int_Instance
})
node = ast.BinaryOp(
location=L,
left=ast.Identifier(L, 'x'),
op='=',
right=ast.FunctionCall(
location=L,
name='f',
args=[]
)
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
import pytest
from compiler.type_checker import typecheck, TypeSymTab
from compiler import ast
from compiler.types import Int_Instance, Bool_Instance, Unit_Instance, FunType
from compiler.tokenizer import L
def test_type_checker_var_typed_declaration_valid() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, 5),
type_f=Int_Instance
)
result = typecheck(node)
assert result == Int_Instance
assert node.value.type == Int_Instance
def test_type_checker_var_typed_declaration_mismatch() -> None:
node = ast.Var(
location=L,
name='x',
value=ast.Literal(L, True),
type_f=Int_Instance
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_assignment_non_identifier() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5), # Invalid left side
op='=',
right=ast.Literal(L, 10)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_user_function_call_args_count_mismatch() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Int_Instance], result=Int_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5)] # Missing second arg
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_user_function_call_arg_type_mismatch() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance], result=Bool_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, True)] # Bool instead of Int
)
with pytest.raises(Exception):
typecheck(node, sym_tab)
def test_type_checker_if_then_returns_unit() -> None:
node = ast.If(
location=L,
cond_exp=ast.Literal(L, True),
then_exp=ast.Literal(L, 5),
else_exp=None
)
result = typecheck(node)
assert result == Unit_Instance
def test_type_checker_equality_mixed_types() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='==',
right=ast.Literal(L, False)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_logical_op_non_boolean() -> None:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, 5),
op='and',
right=ast.Literal(L, 0)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_nested_scope_shadowing() -> None:
outer = TypeSymTab(locals={'x': Int_Instance})
inner = TypeSymTab(parent=outer)
inner.locals['x'] = Bool_Instance # Shadow outer x
node = ast.Identifier(L, 'x')
result = typecheck(node, inner)
assert result == Bool_Instance
def test_type_checker_function_type_multiple_params() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[Int_Instance, Bool_Instance], result=Unit_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[
ast.Literal(L, 5),
ast.Literal(L, True)
]
)
result = typecheck(node, sym_tab)
assert result == Unit_Instance
def test_type_checker_read_int_builtin() -> None:
node = ast.FunctionCall(
location=L,
name='read_int',
args=[]
)
result = typecheck(node)
assert result == Int_Instance
def test_type_checker_block_trailing_semicolon_returns_unit() -> None:
node = ast.Block(
location=L,
statements=[
ast.Literal(L, 5),
ast.Literal(L, None) # Trailing semicolon case
]
)
result = typecheck(node)
assert result == Unit_Instance
def test_type_checker_not_operator_non_boolean() -> None:
node = ast.UnaryOp(
location=L,
op='not',
right=ast.Literal(L, 5) # Int instead of Bool
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_equality_operator_same_types() -> None:
for val in [5, True]:
node = ast.BinaryOp(
location=L,
left=ast.Literal(L, val),
op='==',
right=ast.Literal(L, val)
)
result = typecheck(node)
assert result == Bool_Instance
def test_type_checker_function_type_nested_params() -> None:
# Test function type with complex parameter structure
fun_type = FunType(
params=[[Int_Instance, Bool_Instance], Int_Instance], # From spec's TODO
result=Unit_Instance
)
sym_tab = TypeSymTab(locals={'f': fun_type})
# Valid call with Int then Int (first param allows Int or Bool)
node = ast.FunctionCall(
location=L,
name='f',
args=[ast.Literal(L, 5), ast.Literal(L, 5)]
)
with pytest.raises(Exception): # Second arg should be Int, but first allows Int
typecheck(node, sym_tab) # Actual logic may vary based on implementation
def test_type_checker_operator_precedence_type_resolution() -> None:
# Ensure operator precedence doesn't affect type checking
node = ast.BinaryOp(
location=L,
left=ast.BinaryOp(
location=L,
left=ast.Literal(L, 3),
op='*',
right=ast.Literal(L, 4)
),
op='+',
right=ast.Literal(L, 5)
)
result = typecheck(node)
assert result == Int_Instance
def test_type_checker_while_loop_condition_non_boolean() -> None:
node = ast.While(
location=L,
while_exp=ast.Literal(L, 5), # Non-bool condition
do_exp=ast.Literal(L, 0)
)
with pytest.raises(Exception):
typecheck(node)
def test_type_checker_builtin_print_bool_type() -> None:
node = ast.FunctionCall(
location=L,
name='print_bool',
args=[ast.Literal(L, True)]
)
result = typecheck(node)
assert result == Unit_Instance
def test_type_checker_function_return_type_mismatch() -> None:
sym_tab = TypeSymTab(locals={
'f': FunType(params=[], result=Int_Instance)
})
node = ast.FunctionCall(
location=L,
name='f',
args=[]
)
result = typecheck(node, sym_tab)
assert result == Int_Instance # Dummy test to be replaced with actual logic
def test_type_checker_complex_function_type_annotation() -> None:
# Test type annotation like (Int, (Bool) => Unit) => Int
fun_type = FunType(
params=[
Int_Instance,
FunType(params=[Bool_Instance], result=Unit_Instance)
],
result=Int_Instance
)
sym_tab = TypeSymTab(locals={'f': fun_type})
node = ast.FunctionCall(
location=L,
name='f',
args=[
ast.Literal(L, 5),
ast.Identifier(L, 'g') # Assume 'g' has correct type
]
)
# Actual test would require 'g' to be in symbol table
with pytest.raises(Exception): # 'g' not defined here
typecheck(node, sym_tab)
def test_type_checker_undefined_function_call() -> None:
node = ast.FunctionCall(
location=L,
name='undefined_func',
args=[]
)
with pytest.raises(Exception):
typecheck(node)