#!/usr/bin/env python3
"""Fuzzing driver: mutate seed SQL queries and hand each one to ``check()``.

Two modes are available (see ``main``):

* ``--run-baseline``: run the on-disk seeds unmodified, sequentially.
* default: repeatedly pick a query from the pool, mutate it, and submit it
  to a thread pool that calls ``check()`` with the buggy and the reference
  SQLite binaries given on the command line.
"""
import argparse
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
import datetime
import multiprocessing
import os
import random
import signal
import time

from mutator import mutate
from mutator_extra_statements import (
    mut_prepend_random_meta_command,
    mut_prepend_random_pragma,
)
from oracle import check
from runner import reset_coverage
from stats import create_stats_report


class _MutateTimeout(Exception):
    """Raised by the SIGALRM handler to abort a mutate() call that runs too long."""


def _timeout_handler(signum, frame):
    """SIGALRM handler: convert the alarm into a ``_MutateTimeout`` exception."""
    raise _MutateTimeout()


# The interval-timer based timeout needs both SIGALRM and setitimer; when
# either is missing, mutate_safe() falls back to running without a timeout.
_HAS_SIGALRM = hasattr(signal, 'SIGALRM') and hasattr(signal, 'setitimer')
if _HAS_SIGALRM:
    signal.signal(signal.SIGALRM, _timeout_handler)


def mutate_safe(query: str, timeout: float = 0.5) -> str:
    """Run ``mutate(query)`` with a real, in-process timeout (no threads).

    Args:
        query: The query text to mutate.
        timeout: Seconds before the mutation is aborted. Passed straight to
            ``signal.setitimer``, so ``0`` arms no timer at all. Ignored when
            SIGALRM/setitimer are unavailable.

    Returns:
        The mutated query on success. The original ``query`` is returned when
        the mutation times out, raises, or yields something that is not a
        ``str``.
    """
    if not _HAS_SIGALRM:
        try:
            result = mutate(query)
        except Exception:
            return query
        # Same result guard as the timer path below, so both agree.
        return result if isinstance(result, str) else query

    signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        try:
            result = mutate(query)
        finally:
            # Cancel inside the guarded region: if the alarm fires between
            # mutate() returning and this call, the resulting _MutateTimeout
            # is still caught below instead of escaping to the caller.
            signal.setitimer(signal.ITIMER_REAL, 0)
    except Exception:
        # Covers _MutateTimeout (an Exception subclass) and any mutator error.
        return query
    return result if isinstance(result, str) else query


def _read_seed_files(seeds_dir: str) -> list[str]:
    """Return the raw contents of every readable entry in ``seeds_dir``.

    Entries that cannot be opened or read (``OSError``) are skipped. An
    empty list is returned when ``seeds_dir`` is not a directory. Files are
    decoded as latin-1 and returned in ``os.listdir`` order.
    """
    contents: list[str] = []
    if not os.path.isdir(seeds_dir):
        return contents
    for fname in os.listdir(seeds_dir):
        try:
            with open(os.path.join(seeds_dir, fname), 'r',
                      encoding='latin-1') as f:
                contents.append(f.read())
        except OSError:
            continue
    return contents


def _load_seeds(seeds_dir: str) -> list[str]:
    """Load and deduplicate seeds from disk + additional.sql.

    Disk seeds are used verbatim. ``additional.sql`` (looked up in the
    current working directory) is split on ``---`` and each part is
    stripped. Empty entries are dropped and only the first occurrence of
    each distinct text is kept, preserving order. A missing or unreadable
    ``additional.sql`` is silently ignored.
    """
    candidates = _read_seed_files(seeds_dir)
    try:
        with open("additional.sql", "r", encoding='latin-1') as fh:
            raw = fh.read()
    except OSError:
        raw = ''
    candidates.extend(part.strip() for part in raw.split('---'))
    # dict.fromkeys keeps first-seen order while removing duplicates.
    return list(dict.fromkeys(c for c in candidates if c))


def _load_seeds_disk_only(seeds_dir: str) -> list[str]:
    """Load seeds from disk, keeping duplicates and empty files as-is."""
    return _read_seed_files(seeds_dir)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the fuzzing driver."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seeds',
                        default='/home/test/seeds',
                        help='Directory with seed .sql files')
    parser.add_argument('--buggy',
                        default='/home/test/sqlite3-src/build/sqlite3',
                        help='Buggy SQLite binary')
    parser.add_argument('--reference',
                        default='/usr/bin/sqlite3',
                        help='Reference SQLite binary')
    parser.add_argument('--count',
                        type=int,
                        default=10000,
                        help='Number of queries to generate')
    parser.add_argument('--mutate-timeout',
                        type=float,
                        default=0.5,
                        help='Per-mutation timeout in seconds (default 0.5)')
    parser.add_argument(
        '--max-query-length',
        type=int,
        default=100000,
        help='Hard cap on a single query in characters (default 100000). Queries '
        'that grow past this after mutation are reset back to the original '
        'seed for that iteration.')
    parser.add_argument('--workers',
                        type=int,
                        default=multiprocessing.cpu_count(),
                        help="""Number of parallel check() workers.""")
    parser.add_argument(
        '--validate-seeds',
        action='store_true',
        default=False,
        help=
        'Run the upfront seed-validation pass (~2N subprocess calls before the main loop). Off by default.'
    )
    parser.add_argument(
        '--run-baseline',
        action='store_true',
        default=False,
        help=
        'Run only the seeds found in the --seeds directory, without applying any mutation.'
    )
    return parser


def _run_baseline(args: argparse.Namespace) -> None:
    """Run every on-disk seed once, unmodified, and print the stats report."""
    reset_coverage()

    queries = _load_seeds_disk_only(args.seeds)
    print(f'{datetime.datetime.now()}: [INFO] Loaded {len(queries)} seeds.')
    print(
        f'{datetime.datetime.now()}: [INFO] Running {len(queries)} iterations...')

    invalid_queries = 0
    loop_start = time.perf_counter()
    for query in queries:
        # A check() result of -1 is counted as an invalid query.
        if check(args.buggy, args.reference, query, False) == -1:
            invalid_queries += 1
    loop_elapsed = time.perf_counter() - loop_start

    valid_queries = len(queries) - invalid_queries
    print()
    create_stats_report(valid_queries,
                        invalid_queries,
                        queries,
                        qpm_gen=-1,
                        qpm_full=-1,
                        wall_seconds=loop_elapsed)


def _run_fuzzing(args: argparse.Namespace) -> None:
    """Mutate queries from the seed pool and check them on a thread pool.

    Mutation happens on the calling thread (``mutate_safe`` relies on
    SIGALRM); only ``check()`` runs on the worker threads. A mutated query
    that ``check()`` does not report as invalid (-1) replaces its parent in
    the pool, provided the parent slot was not already replaced.
    """
    seeds = _load_seeds(args.seeds)
    queries: list[str] = list(seeds)

    if args.validate_seeds:
        print(f'{datetime.datetime.now()}: [INFO] Validating seeds...')
        queries = [
            seed for seed in seeds
            if check(args.buggy, args.reference, seed, save_bugs=False) != -1
        ]

    # NOTE(review): variants are built from every loaded seed, including
    # seeds that --validate-seeds rejected above. Confirm this is intended.
    variants: list[str] = [
        mut_prepend_random_pragma(
            mut_prepend_random_pragma(
                mut_prepend_random_pragma(
                    mut_prepend_random_meta_command(seed))))
        for seed in seeds
    ]
    queries.extend(variants)

    count = int(args.count)

    print(f'{datetime.datetime.now()}: [INFO] Loaded {len(queries)} seeds.')
    if count > 0 and not queries:
        # Without this guard the first random.randint(0, -1) below would
        # raise an opaque ValueError.
        raise SystemExit(
            f'[ERROR] No seeds found in {args.seeds!r} or additional.sql; '
            'nothing to mutate.')
    print(
        f'{datetime.datetime.now()}: [INFO] Running {args.count} iterations with {args.workers} workers...'
    )

    reset_coverage()

    invalid_queries = 0
    actually_executed_queries: list[str] = []

    # The flag passed as check()'s fourth argument (named save_bugs in the
    # validation call above) is on for the first phase1_end submissions and
    # again from submission index phase2_end onwards.
    phase1_end = min(300, count)
    phase2_end = max(phase1_end, count - 1000)

    gen_time_total = 0.0
    submitted = 0
    # Maps each in-flight future to (pool index, query, parent, flag).
    pending: dict[Future[int], tuple[int, str, str, bool]] = {}

    loop_start = time.perf_counter()
    # The context manager shuts the pool down (waiting for in-flight
    # checks) even when an exception escapes the loop.
    with ThreadPoolExecutor(max_workers=args.workers) as pool:

        def _submit_next() -> None:
            """Mutate one query and submit a check() to the worker pool."""
            nonlocal gen_time_total, submitted
            idx = random.randint(0, len(queries) - 1)
            original = queries[idx]
            with_flag = (submitted < phase1_end) or (submitted >= phase2_end)

            start = time.perf_counter()
            query = mutate_safe(original, timeout=args.mutate_timeout)
            if len(query) > args.max_query_length:
                # Oversized mutants are discarded for this iteration.
                query = original
            gen_time_total += time.perf_counter() - start

            fut = pool.submit(check, args.buggy, args.reference, query,
                              with_flag)
            pending[fut] = (idx, query, original, with_flag)
            submitted += 1

        def _drain_one(fut: Future) -> None:
            """Collect one completed future and update state."""
            nonlocal invalid_queries
            idx, query, original, _with_flag = pending.pop(fut)
            try:
                is_ok = fut.result()
            except Exception:
                # Best effort: a check() that raised is counted as invalid.
                is_ok = -1

            actually_executed_queries.append(query)
            if is_ok == -1:
                invalid_queries += 1
            elif (query != original and idx < len(queries)
                  and queries[idx] == original):
                # Promote the mutant only if its parent still occupies
                # the slot (another mutant may have replaced it already).
                queries[idx] = query

            n = len(actually_executed_queries)
            if n % 100 == 0:
                print(
                    f'{datetime.datetime.now()}: [INFO] Executed {n} queries...'
                )

        # Prime the pool with one task per worker, then refill one-for-one
        # as results come in.
        for _ in range(min(int(args.workers), count)):
            _submit_next()

        while len(actually_executed_queries) < count:
            done, _ = wait(pending, return_when=FIRST_COMPLETED)
            for fut in done:
                _drain_one(fut)
                if submitted < count:
                    _submit_next()
                if len(actually_executed_queries) >= count:
                    break

    loop_elapsed = time.perf_counter() - loop_start

    valid_queries = len(actually_executed_queries) - invalid_queries
    qpm_gen = (count / gen_time_total * 60.0) if gen_time_total > 0 else 0.0
    qpm_full = (count / loop_elapsed * 60.0) if loop_elapsed > 0 else 0.0
    # Guard the per-query average so `--count 0` does not divide by zero.
    ms_per_query = (gen_time_total / count * 1000) if count else 0.0

    print()
    print(f'Wall clock: {loop_elapsed:.2f}s for {count} iterations')
    print(f'Generation-only time: {gen_time_total:.2f}s '
          f'(≈ {ms_per_query:.2f} ms/query)')
    print(f'Throughput (generated): {qpm_gen:.1f} queries/min')
    print(f'Throughput (generated + executed): {qpm_full:.1f} queries/min')

    create_stats_report(valid_queries,
                        invalid_queries,
                        actually_executed_queries,
                        qpm_gen=qpm_gen,
                        qpm_full=qpm_full,
                        wall_seconds=loop_elapsed)


def main():
    """Parse the command line and run either the baseline or the fuzzing loop."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.run_baseline:
        _run_baseline(args)
        return

    # Only the fuzzing loop uses these; reject values it cannot handle
    # before any work is done.
    if args.count < 0:
        parser.error('--count must be >= 0')
    if args.workers < 1:
        parser.error('--workers must be >= 1')
    _run_fuzzing(args)


if __name__ == '__main__':
    main()