318 lines
8.8 KiB
Python
318 lines
8.8 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
|
import datetime
|
|
import multiprocessing
|
|
import os
|
|
import random
|
|
import signal
|
|
import time
|
|
|
|
from mutator import mutate
|
|
from mutator_extra_statements import (
|
|
mut_prepend_random_meta_command,
|
|
mut_prepend_random_pragma,
|
|
)
|
|
from oracle import check
|
|
from runner import reset_coverage
|
|
from stats import create_stats_report
|
|
|
|
|
|
class _MutateTimeout(BaseException):
    """Raised by the SIGALRM handler when a single mutation exceeds its budget.

    Derives from BaseException rather than Exception so that broad
    ``except Exception`` handlers inside the mutator cannot swallow the
    timeout and keep running past the deadline. Code that wants to handle the
    timeout must name this class explicitly (mutate_safe does).
    """
|
|
|
|
|
|
def _timeout_handler(signum, frame):
    """SIGALRM handler: abort whatever is running by raising _MutateTimeout.

    Both arguments are supplied by the ``signal`` machinery and are unused.
    """
    del signum, frame  # required by the handler signature, not needed here
    raise _MutateTimeout
|
|
|
|
|
|
# The in-process mutation timeout needs both SIGALRM and setitimer(). When
# either is missing (e.g. on non-Unix platforms), mutate_safe() falls back to
# running mutate() without any time limit.
_HAS_SIGALRM = all(hasattr(signal, attr) for attr in ('SIGALRM', 'setitimer'))

if _HAS_SIGALRM:
    # Installed once at import time: an expiring ITIMER_REAL timer delivers
    # SIGALRM, and this handler turns it into a _MutateTimeout exception.
    signal.signal(signal.SIGALRM, _timeout_handler)
|
|
|
|
|
|
def mutate_safe(query: str, timeout: float = 0.5) -> str:
    """Run ``mutate(query)`` under a wall-clock timeout without ever raising.

    When SIGALRM/setitimer are available, the limit is enforced in-process
    (no threads) with an ``ITIMER_REAL`` timer whose SIGALRM handler raises
    ``_MutateTimeout``. Python only runs signal handlers in the main thread,
    so call this from the main thread.

    Args:
        query: The query to mutate.
        timeout: Budget in seconds. A value ``<= 0``, or a platform without
            SIGALRM, means the mutation runs without a time limit.

    Returns:
        The mutated query, or *query* unchanged if mutation timed out, raised
        an exception, or returned something other than a ``str``.
    """
    if not _HAS_SIGALRM or timeout <= 0:
        # No timer available/requested: best-effort, unbounded mutation.
        # (A negative value would also make setitimer() itself raise.)
        try:
            result = mutate(query)
        except Exception:
            return query
        return result if isinstance(result, str) else query

    try:
        signal.setitimer(signal.ITIMER_REAL, timeout)
        try:
            result = mutate(query)
        finally:
            # Disarm right away. This finally sits INSIDE the outer try, so an
            # alarm that fires after mutate() returned but before the timer is
            # cleared is still caught below instead of escaping to the caller.
            signal.setitimer(signal.ITIMER_REAL, 0)
    except (_MutateTimeout, Exception):
        # _MutateTimeout is named explicitly so the timeout is caught whatever
        # its base class is; any other mutator failure also keeps the original.
        return query

    return result if isinstance(result, str) else query
|
|
|
|
|
|
def _load_seeds(seeds_dir: str, extra_path: str = 'additional.sql') -> list[str]:
    """Load and deduplicate seed queries from *seeds_dir* plus *extra_path*.

    Every readable entry of *seeds_dir* contributes its whole content verbatim.
    Entries are visited in sorted filename order so that seed order, and which
    duplicate wins, is reproducible across runs and filesystems.
    *extra_path* (default ``additional.sql``, relative to the current working
    directory) is split on ``---`` and each stripped piece becomes a seed.
    A missing directory, a missing extra file, and unreadable entries (e.g.
    sub-directories) are skipped silently.

    Deduplication compares whitespace-stripped text. A disk seed and an extra
    entry that differ only in surrounding whitespace are therefore kept once,
    and blank seeds are dropped. The first occurrence wins.

    Args:
        seeds_dir: Directory whose files are read as individual seeds.
        extra_path: Optional file of ``---``-separated additional seeds.

    Returns:
        The unique, non-blank seeds in load order.
    """
    seeds: list[str] = []
    seen: set[str] = set()

    def _add(text: str) -> None:
        """Append *text* unless blank or already present (modulo whitespace)."""
        key = text.strip()
        if key and key not in seen:
            seen.add(key)
            seeds.append(text)

    if os.path.isdir(seeds_dir):
        try:
            names = sorted(os.listdir(seeds_dir))
        except OSError:
            names = []
        for fname in names:
            try:
                with open(os.path.join(seeds_dir, fname), 'r',
                          encoding='latin-1') as f:
                    content = f.read()
            except OSError:
                continue  # e.g. a sub-directory or an unreadable file
            _add(content)

    try:
        with open(extra_path, 'r', encoding='latin-1') as fh:
            raw = fh.read()
    except OSError:
        pass  # the extra seed file is optional
    else:
        for piece in raw.split('---'):
            _add(piece.strip())

    return seeds
|
|
|
|
|
|
def _load_seeds_disk_only(seeds_dir: str) -> list[str]:
    """Load every readable file in *seeds_dir* as one seed, verbatim.

    Files are read in sorted filename order, so baseline runs execute seeds in
    a reproducible order; ``os.listdir`` itself guarantees no ordering. No
    deduplication or whitespace filtering is applied. Empty files yield empty
    seeds. Unreadable entries (e.g. sub-directories) are skipped, and a missing
    or unlistable directory yields an empty list.

    Args:
        seeds_dir: Directory whose files are read as individual seeds.

    Returns:
        The file contents, one string per readable file.
    """
    if not os.path.isdir(seeds_dir):
        return []
    try:
        names = sorted(os.listdir(seeds_dir))
    except OSError:
        return []

    seeds: list[str] = []
    for fname in names:
        try:
            with open(os.path.join(seeds_dir, fname), 'r',
                      encoding='latin-1') as f:
                seeds.append(f.read())
        except OSError:
            continue  # e.g. a sub-directory or an unreadable file
    return seeds
|
|
|
|
|
|
def _parse_args(argv=None) -> argparse.Namespace:
    """Build the CLI parser, parse *argv* and range-check fuzzing options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. Exits via ``parser.error`` (status 2) when a
        numeric option used by the fuzzing loop is out of range.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--seeds',
                        default='/home/test/seeds',
                        help='Directory with seed .sql files')
    parser.add_argument('--buggy',
                        default='/home/test/sqlite3-src/build/sqlite3',
                        help='Buggy SQLite binary')
    parser.add_argument('--reference',
                        default='/usr/bin/sqlite3',
                        help='Reference SQLite binary')
    parser.add_argument('--count',
                        type=int,
                        default=10000,
                        help='Number of queries to generate')
    parser.add_argument('--mutate-timeout',
                        type=float,
                        default=0.5,
                        help='Per-mutation timeout in seconds (default 0.5)')
    parser.add_argument(
        '--max-query-length',
        type=int,
        default=100000,
        help='Hard cap on a single query in characters (default 100000). Queries '
        'that grow past this after mutation are reset back to the original '
        'seed for that iteration.')
    parser.add_argument('--workers',
                        type=int,
                        default=multiprocessing.cpu_count(),
                        help="""Number of parallel check() workers.""")
    parser.add_argument(
        '--validate-seeds',
        action='store_true',
        default=False,
        help=
        'Run the upfront seed-validation pass (~2N subprocess calls before the main loop). Off by default.'
    )
    parser.add_argument(
        '--run-baseline',
        action='store_true',
        default=False,
        help=
        'Run the seeds inside the --seeds directory only, without applying any mutation.'
    )

    args = parser.parse_args(argv)

    # Baseline mode ignores these options, so only validate them for the
    # fuzzing loop (count=0 is allowed and simply runs nothing).
    if not args.run_baseline:
        if args.count < 0:
            parser.error('--count must be >= 0')
        if args.workers < 1:
            parser.error('--workers must be >= 1')
        if args.mutate_timeout < 0:
            parser.error('--mutate-timeout must be >= 0')
    return args


def _run_baseline(args: argparse.Namespace) -> None:
    """Run every on-disk seed once, sequentially and unmutated, then report.

    Throughput numbers are not computed in this mode; -1 is passed for both
    qpm values.
    """
    reset_coverage()

    queries: list[str] = list(_load_seeds_disk_only(args.seeds))

    print(f'{datetime.datetime.now()}: [INFO] Loaded {len(queries)} seeds.')
    print(
        f'{datetime.datetime.now()}: [INFO] Running {len(queries)} iterations...')

    invalid_queries = 0
    loop_start = time.perf_counter()
    for query in queries:
        # 4th argument presumably is save_bugs (cf. the keyword call used for
        # seed validation); bug saving is off for the baseline.
        if check(args.buggy, args.reference, query, False) == -1:
            invalid_queries += 1
    loop_elapsed = time.perf_counter() - loop_start

    valid_queries = len(queries) - invalid_queries

    print()
    create_stats_report(valid_queries,
                        invalid_queries,
                        queries,
                        qpm_gen=-1,
                        qpm_full=-1,
                        wall_seconds=loop_elapsed)


def _build_fuzz_corpus(args: argparse.Namespace) -> list[str]:
    """Return the initial mutation corpus: base seeds followed by variants.

    The base is every loaded seed, or, with --validate-seeds, only those for
    which check() does not return -1. Each variant is a base seed passed
    through mut_prepend_random_meta_command once and mut_prepend_random_pragma
    three times. Variants are built from the *validated* base so that
    validation is not undone by re-adding variants of rejected seeds.
    """
    seeds = _load_seeds(args.seeds)

    if args.validate_seeds:
        print(f'{datetime.datetime.now()}: [INFO] Validating seeds...')
        base = [
            seed for seed in seeds
            if check(args.buggy, args.reference, seed, save_bugs=False) != -1
        ]
    else:
        base = list(seeds)

    variants = [
        mut_prepend_random_pragma(
            mut_prepend_random_pragma(
                mut_prepend_random_pragma(mut_prepend_random_meta_command(seed))))
        for seed in base
    ]
    return base + variants


def _run_fuzz(args: argparse.Namespace) -> None:
    """Mutate-and-check loop, then print timing and write the stats report.

    Keeps up to ``args.workers`` check() calls in flight until ``args.count``
    queries have been executed. Mutation stays on the calling thread (the
    SIGALRM-based timeout in mutate_safe needs the main thread); only check()
    is offloaded to the thread pool. A valid mutation replaces its corpus
    entry, so later iterations build on it.
    """
    queries = _build_fuzz_corpus(args)

    print(f'{datetime.datetime.now()}: [INFO] Loaded {len(queries)} seeds.')
    print(
        f'{datetime.datetime.now()}: [INFO] Running {args.count} iterations with {args.workers} workers...'
    )

    count = args.count
    if count > 0 and not queries:
        # Without this, random index selection fails with an opaque ValueError.
        raise SystemExit(f'{datetime.datetime.now()}: [ERROR] No seeds loaded '
                         f'from {args.seeds!r}; nothing to mutate.')

    reset_coverage()
    invalid_queries = 0
    actually_executed_queries: list[str] = []

    # Submissions in [0, phase1_end) and [phase2_end, count) pass True as
    # check()'s 4th argument; the ones in between pass False.
    phase1_end = min(300, count)
    phase2_end = max(phase1_end, count - 1000)

    gen_time_total = 0.0
    loop_start = time.perf_counter()

    # future -> (corpus index, mutated query, original query, flag given to check)
    pending: dict[Future[int], tuple[int, str, str, bool]] = {}
    pool = ThreadPoolExecutor(max_workers=args.workers)
    submitted = 0

    def _submit_next() -> None:
        """Mutate one random corpus entry and submit a check() for it."""
        nonlocal gen_time_total, submitted
        idx = random.randrange(len(queries))
        original = queries[idx]
        with_flag = submitted < phase1_end or submitted >= phase2_end

        start = time.perf_counter()
        query = mutate_safe(original, timeout=args.mutate_timeout)
        if len(query) > args.max_query_length:
            query = original  # runaway growth: fall back to the unmutated entry
        gen_time_total += time.perf_counter() - start

        fut = pool.submit(check, args.buggy, args.reference, query, with_flag)
        pending[fut] = (idx, query, original, with_flag)
        submitted += 1

    def _drain_one(fut: Future) -> None:
        """Record one finished check() and keep the mutation if it was valid."""
        nonlocal invalid_queries
        idx, query, original, _with_flag = pending.pop(fut)
        try:
            is_ok = fut.result()
        except Exception:
            is_ok = -1  # a check() that raised counts as an invalid query
        actually_executed_queries.append(query)
        if is_ok == -1:
            invalid_queries += 1
        elif query != original and queries[idx] == original:
            # Only replace the entry if no other finished result already did.
            queries[idx] = query

        n = len(actually_executed_queries)
        if n % 100 == 0:
            print(f'{datetime.datetime.now()}: [INFO] Executed {n} queries...')

    try:
        for _ in range(min(args.workers, count)):
            _submit_next()

        while len(actually_executed_queries) < count:
            done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED)
            for fut in done:
                _drain_one(fut)
                if submitted < count:
                    _submit_next()  # refill the worker slot that just freed up
    finally:
        # On an interrupt or error, drop queued-but-unstarted checks instead of
        # running all of them before exiting; a no-op after a normal finish.
        pool.shutdown(wait=True, cancel_futures=True)

    loop_elapsed = time.perf_counter() - loop_start

    valid_queries = len(actually_executed_queries) - invalid_queries

    # Every division is guarded: --count 0 would otherwise divide by zero.
    qpm_gen = (count / gen_time_total * 60.0) if gen_time_total > 0 else 0.0
    qpm_full = (count / loop_elapsed * 60.0) if loop_elapsed > 0 else 0.0
    ms_per_query = (gen_time_total / count * 1000) if count > 0 else 0.0

    print()
    print(f'Wall clock: {loop_elapsed:.2f}s for {count} iterations')
    print(f'Generation-only time: {gen_time_total:.2f}s '
          f'(≈ {ms_per_query:.2f} ms/query)')
    print(f'Throughput (generated): {qpm_gen:.1f} queries/min')
    print(f'Throughput (generated + executed): {qpm_full:.1f} queries/min')

    create_stats_report(valid_queries,
                        invalid_queries,
                        actually_executed_queries,
                        qpm_gen=qpm_gen,
                        qpm_full=qpm_full,
                        wall_seconds=loop_elapsed)


def main():
    """Command-line entry point: run the baseline pass or the fuzzing loop."""
    args = _parse_args()
    if args.run_baseline:
        _run_baseline(args)
    else:
        _run_fuzz(args)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|