Files
ast-project/part1/fuzzer.py
T
2026-06-24 13:47:14 +02:00

318 lines
8.8 KiB
Python

#!/usr/bin/env python3
import argparse
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
import datetime
import multiprocessing
import os
import random
import signal
import time
from mutator import mutate
from mutator_extra_statements import (
mut_prepend_random_meta_command,
mut_prepend_random_pragma,
)
from oracle import check
from runner import reset_coverage
from stats import create_stats_report
class _MutateTimeout(Exception):
pass
def _timeout_handler(signum, frame):
    """SIGALRM handler: abort the running mutation by raising ``_MutateTimeout``.

    ``signum`` and ``frame`` are supplied by the ``signal`` module and are
    intentionally ignored.
    """
    raise _MutateTimeout
# SIGALRM and setitimer() exist only on POSIX platforms; without them the
# mutator is run with no time limit.
_HAS_SIGALRM = all(hasattr(signal, attr) for attr in ('SIGALRM', 'setitimer'))
if _HAS_SIGALRM:
    # Registered once at import time. signal.signal() may only be called from
    # the main thread, so this module must be imported there.
    signal.signal(signal.SIGALRM, _timeout_handler)
def mutate_safe(query: str, timeout: float = 0.5) -> str:
    """
    Run ``mutate(query)`` under an in-process SIGALRM timeout (no threads).

    Returns the mutated query on success, or ``query`` unchanged when the
    mutator times out, raises an ``Exception``, or returns a non-string.

    Args:
        query: SQL text to mutate.
        timeout: Time budget in seconds. A value ``<= 0`` (or a platform
            without SIGALRM/setitimer) runs the mutator with no time limit,
            matching setitimer's "0 disarms the timer" semantics.

    NOTE(review): Python runs signal handlers only on the main thread, so this
    is expected to be called from the main thread (as main() does) - confirm
    for any new caller.
    """
    if not _HAS_SIGALRM or timeout <= 0:
        try:
            result = mutate(query)
        except Exception:
            return query
        # Same non-string fallback as the timed path below.
        return result if isinstance(result, str) else query

    try:
        signal.setitimer(signal.ITIMER_REAL, timeout)
        try:
            result = mutate(query)
        finally:
            # Disarm before anything else; the timer is one-shot.
            signal.setitimer(signal.ITIMER_REAL, 0)
    except _MutateTimeout:
        # Deadline hit, possibly just after mutate() returned but before the
        # timer was disarmed - the outer try keeps that race from escaping.
        return query
    except Exception:
        return query
    return result if isinstance(result, str) else query
def _load_seeds(seeds_dir: str) -> list[str]:
"""Load and deduplicate seeds from disk + additional.sql"""
seeds: list[str] = []
seen: set[str] = set()
if os.path.isdir(seeds_dir):
for fname in os.listdir(seeds_dir):
try:
with open(os.path.join(seeds_dir, fname), 'r', encoding='latin-1') as f:
content = f.read()
except OSError:
continue
if content and content not in seen:
seen.add(content)
seeds.append(content)
try:
with open("additional.sql", "r", encoding='latin-1') as fh:
raw = fh.read()
for q in raw.split('---'):
q = q.strip()
if q and q not in seen:
seen.add(q)
seeds.append(q)
except OSError:
pass
return seeds
def _load_seeds_disk_only(seeds_dir: str) -> list[str]:
"""Load seeds from disk"""
seeds: list[str] = []
if os.path.isdir(seeds_dir):
for fname in os.listdir(seeds_dir):
try:
with open(os.path.join(seeds_dir, fname), 'r', encoding='latin-1') as f:
content = f.read()
except OSError:
continue
seeds.append(content)
return seeds
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser (flags, defaults and help text)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seeds',
                        default='/home/test/seeds',
                        help='Directory with seed .sql files')
    parser.add_argument('--buggy',
                        default='/home/test/sqlite3-src/build/sqlite3',
                        help='Buggy SQLite binary')
    parser.add_argument('--reference',
                        default='/usr/bin/sqlite3',
                        help='Reference SQLite binary')
    parser.add_argument('--count',
                        type=int,
                        default=10000,
                        help='Number of queries to generate')
    parser.add_argument('--mutate-timeout',
                        type=float,
                        default=0.5,
                        help='Per-mutation timeout in seconds (default 0.5)')
    parser.add_argument(
        '--max-query-length',
        type=int,
        default=100000,
        help='Hard cap on a single query in characters (default 100000). Queries '
        'that grow past this after mutation are replaced by the pre-mutation '
        'query for that iteration.')
    parser.add_argument('--workers',
                        type=int,
                        default=multiprocessing.cpu_count(),
                        help="""Number of parallel check() workers.""")
    parser.add_argument(
        '--validate-seeds',
        action='store_true',
        default=False,
        help=
        'Run the upfront seed-validation pass (~2N subprocess calls before the main loop). Off by default.'
    )
    parser.add_argument(
        '--run-baseline',
        action='store_true',
        default=False,
        help=
        'Run the seeds inside the --seeds directory (default /home/test/seeds) only, without applying any mutation.'
    )
    return parser


def _run_baseline(args: argparse.Namespace) -> None:
    """Execute every on-disk seed once, unmutated, then write the stats report."""
    reset_coverage()
    queries: list[str] = list(_load_seeds_disk_only(args.seeds))
    print(f'{datetime.datetime.now()}: [INFO] Loaded {len(queries)} seeds.')
    print(
        f'{datetime.datetime.now()}: [INFO] Running {len(queries)} iterations...')
    invalid_queries = 0
    loop_start = time.perf_counter()
    for query in queries:
        # A -1 result from check() is counted as an invalid query.
        if check(args.buggy, args.reference, query, False) == -1:
            invalid_queries += 1
    loop_elapsed = time.perf_counter() - loop_start
    valid_queries = len(queries) - invalid_queries
    print()
    create_stats_report(valid_queries,
                        invalid_queries,
                        queries,
                        qpm_gen=-1,
                        qpm_full=-1,
                        wall_seconds=loop_elapsed)


def _build_query_pool(args: argparse.Namespace) -> list[str]:
    """Return the initial mutation corpus: the seeds plus prefixed seed variants.

    With ``--validate-seeds``, seeds for which check() returns -1 are dropped
    from the base list first.
    """
    seeds = _load_seeds(args.seeds)
    queries: list[str] = list(seeds)
    if args.validate_seeds:
        print(f'{datetime.datetime.now()}: [INFO] Validating seeds...')
        queries = [
            seed for seed in seeds
            if check(args.buggy, args.reference, seed, save_bugs=False) != -1
        ]
    # Each variant gets a random meta-command and three random pragmas
    # prepended.
    # NOTE(review): variants are derived from *all* seeds, including ones
    # rejected by validation above - confirm this is intended.
    for seed in seeds:
        queries.append(
            mut_prepend_random_pragma(
                mut_prepend_random_pragma(
                    mut_prepend_random_pragma(
                        mut_prepend_random_meta_command(seed)))))
    return queries


def _run_fuzzing(args: argparse.Namespace) -> None:
    """Run ``args.count`` mutate-then-check iterations and report statistics.

    Mutation is done on the calling thread (presumably the main thread, which
    mutate_safe's SIGALRM timeout requires). Only the check() calls run on
    the thread pool. A mutated query that check() does not reject (-1)
    replaces its parent in the corpus.
    """
    queries = _build_query_pool(args)
    print(f'{datetime.datetime.now()}: [INFO] Loaded {len(queries)} seeds.')
    print(
        f'{datetime.datetime.now()}: [INFO] Running {args.count} iterations with {args.workers} workers...'
    )
    count = int(args.count)
    if count > 0 and not queries:
        # Fail clearly instead of crashing in random.randint(0, -1), and
        # before reset_coverage() discards the existing data.
        raise SystemExit(f'{datetime.datetime.now()}: [ERROR] No seed queries '
                         'loaded; nothing to mutate.')
    reset_coverage()

    invalid_queries = 0
    executed: list[str] = []
    # with_flag is passed as check()'s 4th argument (presumably save_bugs,
    # cf. the validation call). It is True for the first 300 and the last
    # 1000 submissions.
    phase1_end = min(300, count)
    phase2_end = max(phase1_end, count - 1000)
    gen_time_total = 0.0
    submitted = 0
    pending: dict[Future[int], tuple[int, str, str, bool]] = {}

    loop_start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=args.workers) as pool:

        def _submit_next() -> None:
            """Mutate one random corpus entry and submit its check() to the pool."""
            nonlocal gen_time_total, submitted
            idx = random.randint(0, len(queries) - 1)
            original = queries[idx]
            with_flag = (submitted < phase1_end) or (submitted >= phase2_end)
            start = time.perf_counter()
            query = mutate_safe(original, timeout=args.mutate_timeout)
            if len(query) > args.max_query_length:
                query = original
            gen_time_total += time.perf_counter() - start
            fut = pool.submit(check, args.buggy, args.reference, query,
                              with_flag)
            pending[fut] = (idx, query, original, with_flag)
            submitted += 1

        def _drain_one(fut: Future) -> None:
            """Record one finished check() and update the corpus."""
            nonlocal invalid_queries
            idx, query, original, _with_flag = pending.pop(fut)
            try:
                is_ok = fut.result()
            except Exception:
                # A crashed check() counts as invalid instead of aborting
                # the run.
                is_ok = -1
            executed.append(query)
            if is_ok == -1:
                invalid_queries += 1
            elif query != original and queries[idx] == original:
                # Replace the parent only if no other completion already did.
                queries[idx] = query
            n = len(executed)
            if n % 100 == 0:
                print(f'{datetime.datetime.now()}: [INFO] Executed {n} queries...')

        for _ in range(min(int(args.workers), count)):
            _submit_next()
        # Each completion frees a slot that is refilled until `count` checks
        # have been submitted. The loop ends once all of them are drained.
        while len(executed) < count:
            done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED)
            for fut in done:
                _drain_one(fut)
                if submitted < count:
                    _submit_next()
    loop_elapsed = time.perf_counter() - loop_start

    valid_queries = len(executed) - invalid_queries
    qpm_gen = (count / gen_time_total * 60.0) if gen_time_total > 0 else 0.0
    qpm_full = (count / loop_elapsed * 60.0) if loop_elapsed > 0 else 0.0
    # Guard against --count 0, which previously raised ZeroDivisionError.
    ms_per_query = (gen_time_total / count * 1000) if count > 0 else 0.0
    print()
    print(f'Wall clock: {loop_elapsed:.2f}s for {count} iterations')
    print(f'Generation-only time: {gen_time_total:.2f}s '
          f'(≈ {ms_per_query:.2f} ms/query)')
    print(f'Throughput (generated): {qpm_gen:.1f} queries/min')
    print(f'Throughput (generated + executed): {qpm_full:.1f} queries/min')
    create_stats_report(valid_queries,
                        invalid_queries,
                        executed,
                        qpm_gen=qpm_gen,
                        qpm_full=qpm_full,
                        wall_seconds=loop_elapsed)


def main():
    """Parse the CLI and run either the unmutated baseline or the fuzz loop."""
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.run_baseline:
        _run_baseline(args)
        return
    # Reject values that would otherwise crash deep inside the loop
    # (ThreadPoolExecutor needs >= 1 worker) or produce a meaningless report.
    if args.count < 0:
        parser.error('--count must be >= 0')
    if args.workers < 1:
        parser.error('--workers must be >= 1')
    _run_fuzzing(args)
if __name__ == '__main__':
    # Script entry point; importing this module does not start a run.
    main()