Files
2026-06-24 13:47:14 +02:00

372 lines
11 KiB
Python

#!/usr/bin/env python3
import argparse
from concurrent.futures import ThreadPoolExecutor
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import time
from typing import Callable, Dict, List, Sequence, Tuple
import sqlglot
from sqlglot.tokenizer_core import Token
def tokenize(query: str) -> List[Token]:
    """Split *query* into sqlglot tokens.

    A fresh ``sqlglot.Tokenizer()`` (constructed with no arguments) is used for
    every call, and its output is materialized into a list. Tokenizer errors
    propagate to the caller.
    """
    tokenizer = sqlglot.Tokenizer()
    produced = tokenizer.tokenize(query)
    return list(produced)
def token_count(query: str) -> int:
    """Return how many sqlglot tokens *query* contains.

    If tokenization raises, fall back to counting whitespace-separated words so
    that progress reporting never fails.
    """
    try:
        tokens = tokenize(query)
    except Exception:
        # Best-effort metric only: an approximate count beats crashing.
        return len(query.split())
    return len(tokens)
def _ttype(tok: Token) -> str:
"""Returns the token_type of a Token as a string."""
return tok.token_type.name if hasattr(tok.token_type, 'name') else str(
tok.token_type)
def assemble_tokens(source: str, tokens: List[Token],
                    kept: Sequence[int]) -> str:
    """Rebuild *source* keeping only the tokens whose indices are in *kept*.

    Only the character spans ``[tok.start, tok.end]`` of dropped tokens are
    deleted. Everything else (whitespace and any text between tokens) is copied
    verbatim. Spans of dropped tokens that touch each other are coalesced
    first, so a run of adjacent dropped tokens is deleted as a single unit. If a
    deletion would glue two non-space characters together (e.g. dropping ``+``
    in ``a+b``), exactly one space is inserted in its place.

    Args:
        source: The text the tokens were produced from.
        tokens: Tokens exposing integer ``start``/``end`` offsets into
            *source* (``end`` inclusive).
        kept: Indices into *tokens* to keep.

    Returns:
        The reassembled text, or ``''`` when *kept* is empty.
    """
    keep = set(kept)
    if not keep:
        return ''
    spans = sorted((tok.start, tok.end + 1)
                   for i, tok in enumerate(tokens)
                   if i not in keep)
    # Merge touching spans into maximal runs. Without this, the neighbour
    # checks below would inspect characters that are themselves being deleted.
    # That emits one separator per dropped token (e.g. "a+-b" -> "a  b").
    runs: List[List[int]] = []
    for s, e in spans:
        if runs and s <= runs[-1][1]:
            runs[-1][1] = max(runs[-1][1], e)
        else:
            runs.append([s, e])
    out: List[str] = []
    pos = 0
    for s, e in runs:
        out.append(source[pos:s])
        # Runs are maximal, so source[s - 1] and source[e] (when in range)
        # survive into the output. A space is needed only if both are
        # non-whitespace; the string boundaries count as whitespace.
        left = source[s - 1] if s > 0 else ' '
        right = source[e] if e < len(source) else ' '
        if not left.isspace() and not right.isspace():
            out.append(' ')
        pos = e
    out.append(source[pos:])
    return ''.join(out)
class Oracle:
    """Decides whether a candidate query still reproduces the bug.

    Each candidate is written to a fresh file under a private temporary
    directory. The test script is then run as ``bash <script>`` with the
    script's directory as the working directory and the candidate file's path
    exported as ``TEST_CASE_LOCATION``.

    Exit status 0 means "reproduces". Any other status, a timeout, or a failure
    to launch the script means "does not reproduce". Verdicts are memoized by
    candidate text. Call :meth:`cleanup` when done to stop the worker pool and
    delete the temporary directory.
    """

    def __init__(self, test_script: str, jobs: int, timeout: float) -> None:
        # Absolute path to the oracle script (e.g. `/reducer/queries/query1/script.sh`)
        self.test_script = os.path.abspath(test_script)
        # Directory of the oracle script; used as the subprocess working directory
        self.script_dir = os.path.dirname(self.test_script)
        # Seconds allowed per oracle subprocess; on expiry the run counts as "does not reproduce".
        # NOTE(review): subprocess.run kills only the direct `bash` child on timeout;
        # processes the script spawned may outlive it.
        self.timeout = timeout
        # Number of parallel worker threads (clamped to at least 1)
        self.jobs = max(1, jobs)
        # Private temporary directory for candidate files, e.g. `/tmp/reducer-XXXX`
        self.workdir = tempfile.mkdtemp(prefix='reducer-')
        # Memoized verdicts: candidate text -> reproduces?
        self._cache: Dict[str, bool] = {}
        # Guards the cache and the `calls` / `cache_hits` counters
        self._cache_lock = threading.Lock()
        # Guards `_counter` so parallel workers never share a candidate filename
        self._id_lock = threading.Lock()
        # Sequence number used to build unique candidate filenames
        self._counter = 0
        # Number of real script executions (cache misses)
        self.calls = 0
        # Number of verdicts served directly from the cache
        self.cache_hits = 0
        # Thread pool only when jobs > 1; None forces sequential evaluation
        self._pool = (ThreadPoolExecutor(max_workers=self.jobs)
                      if self.jobs > 1 else None)

    def _next_path(self) -> str:
        """Return a fresh, unique file path inside the working directory."""
        with self._id_lock:
            self._counter += 1
            return os.path.join(self.workdir, f'cand-{self._counter}.sql')

    def _run_uncached(self, candidate: str) -> bool:
        """Write *candidate* to a fresh file and run the oracle script on it.

        Returns True iff the script exits with status 0. A timeout or any
        error raised while launching or running the script counts as "does
        not reproduce". Removing the candidate file afterwards is best-effort.
        """
        path = self._next_path()
        try:
            with open(path, 'w') as f:
                f.write(candidate)
            env = os.environ.copy()
            env['TEST_CASE_LOCATION'] = path
            try:
                result = subprocess.run(
                    ['bash', self.test_script],
                    cwd=self.script_dir,
                    env=env,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    timeout=self.timeout,
                )
            except Exception:
                # Covers subprocess.TimeoutExpired as well as launch failures:
                # either way the candidate is not considered interesting.
                return False
            return result.returncode == 0
        finally:
            try:
                os.remove(path)
            except OSError:
                pass

    def check(self, candidate: str) -> bool:
        """Return whether *candidate* reproduces the bug, consulting the cache first.

        Whitespace-only candidates are rejected without running the script;
        they are neither cached nor counted. Two threads racing on the same
        uncached candidate may both run the script. :meth:`check_many` avoids
        that within a single batch.
        """
        if not candidate.strip():
            return False
        with self._cache_lock:
            if candidate in self._cache:
                self.cache_hits += 1
                return self._cache[candidate]
        verdict = self._run_uncached(candidate)
        with self._cache_lock:
            self.calls += 1
            self._cache[candidate] = verdict
        return verdict

    def check_many(self, candidates: List[str]) -> List[bool]:
        """Evaluate several candidates, in parallel when ``jobs > 1``.

        Duplicate candidates are evaluated once. Otherwise they would run
        concurrently, each miss the cache, and launch the script repeatedly.
        The returned list is aligned with *candidates*.
        """
        unique = list(dict.fromkeys(candidates))
        if self._pool is None or len(unique) <= 1:
            verdicts = [self.check(c) for c in unique]
        else:
            verdicts = list(self._pool.map(self.check, unique))
        by_candidate = dict(zip(unique, verdicts))
        return [by_candidate[c] for c in candidates]

    def cleanup(self) -> None:
        """Shut down the worker pool (if any) and delete the working directory."""
        if self._pool is not None:
            self._pool.shutdown(wait=True)
        shutil.rmtree(self.workdir, ignore_errors=True)
# Maps a list of kept item indices to the candidate text handed to the oracle.
BuildFn = Callable[[List[int]], str]


def ddmin(n_items: int, build: BuildFn, oracle: Oracle) -> List[int]:
    """Delta-debugging minimisation over the items ``0 .. n_items - 1``.

    At granularity ``n``, the current kept list is split into exactly ``n``
    contiguous, near-equal chunks. First each chunk alone is tried; on success
    it becomes the new kept list and granularity resets to 2. Otherwise (for
    ``n > 2``) each complement, i.e. everything except one chunk, is tried; on
    success granularity drops by one. If neither succeeds, granularity doubles
    (capped at the list length). The search stops when single-item chunks have
    been tried without progress or fewer than two items remain.

    It is assumed that ``build(list(range(n_items)))`` already passes the oracle.
    That input is never re-checked.

    Args:
        n_items: Number of removable items.
        build: Turns a list of kept indices into a candidate.
        oracle: Object whose ``check_many(candidates)`` returns one bool per
            candidate.

    Returns:
        The sorted list of kept indices.
    """
    kept: List[int] = list(range(n_items))
    n = 2
    while len(kept) >= 2:
        n = min(n, len(kept))
        size = len(kept)
        # Balanced split into exactly n non-empty chunks. Floor-dividing the
        # chunk size instead can produce up to 2n-1 chunks and jump to fine
        # granularity too early.
        chunks = [kept[size * i // n:size * (i + 1) // n] for i in range(n)]
        # Does any single chunk still reproduce the bug?
        subset_results = oracle.check_many([build(c) for c in chunks])
        hit = next((i for i, ok in enumerate(subset_results) if ok), None)
        if hit is not None:
            kept = chunks[hit]
            n = 2
            continue
        # With two chunks the complements are exactly the subsets just tried.
        if n > 2:
            complements = [[idx for j, c in enumerate(chunks) if j != i
                            for idx in c]
                           for i in range(n)]
            comp_results = oracle.check_many([build(c) for c in complements])
            hit = next((i for i, ok in enumerate(comp_results) if ok), None)
            if hit is not None:
                kept = complements[hit]
                n = max(n - 1, 2)
                continue
        if n >= len(kept):
            break
        n = min(n * 2, len(kept))
    return kept
def reduce_lines(text: str, oracle: Oracle) -> str:
    """Delta-debug *text* using whole newline-separated lines as the units.

    Text with fewer than two lines is returned unchanged. Otherwise the result
    is the surviving lines, in original order, re-joined with newlines.
    """
    rows = text.split('\n')
    if len(rows) < 2:
        return text

    def render(indices: List[int]) -> str:
        # Re-join only the selected lines, preserving their relative order.
        return '\n'.join(rows[i] for i in indices)

    survivors = ddmin(len(rows), render, oracle)
    return render(survivors)
def reduce_tokens(text: str, oracle: Oracle) -> str:
    """Delta-debug *text* using individual sqlglot tokens as the units.

    *text* is returned unchanged when it cannot be tokenized or has fewer than
    two tokens. Otherwise the result is *text* rebuilt from the token subset
    that ``ddmin`` settles on.
    """
    try:
        toks = tokenize(text)
    except Exception:
        return text
    if len(toks) < 2:
        return text

    def render(indices: List[int]) -> str:
        return assemble_tokens(text, toks, indices)

    survivors = ddmin(len(toks), render, oracle)
    return render(survivors)
def _paren_pairs(tokens: List[Token]) -> List[Tuple[int, int]]:
    """Match parentheses and return ``(open_idx, close_idx)`` index pairs.

    Pairs are reported in the order their closing parenthesis appears.
    Unmatched ``)`` tokens are ignored, and ``(`` tokens still open at the end
    produce no pair.
    """
    opened: List[int] = []
    matched: List[Tuple[int, int]] = []
    for idx, kind in enumerate(map(_ttype, tokens)):
        if kind == 'R_PAREN':
            if opened:
                matched.append((opened.pop(), idx))
        elif kind == 'L_PAREN':
            opened.append(idx)
    return matched
def reduce_brackets(text: str, oracle: Oracle) -> str:
    """Try deleting whole parenthesised groups, optionally with the token before them.

    For every balanced ``(`` ... ``)`` pair two removals are proposed: the
    group alone (e.g. ``(x + y)``) and, when one exists, the group plus the
    token immediately before it (e.g. ``func(args)``). All proposals are
    checked in one batch.

    The widest passing removals that do not overlap are then applied together.
    If that combination fails the oracle, only the single widest passing
    removal is applied. *text* is returned unchanged when tokenization fails,
    there are no pairs, or no proposal passes.
    """
    try:
        tokens = tokenize(text)
    except Exception:
        return text
    pairs = _paren_pairs(tokens)
    if not pairs:
        return text

    all_idx = list(range(len(tokens)))

    def cut(first: int, last: int) -> str:
        # Drop tokens first..last (inclusive) and reassemble the remainder.
        return assemble_tokens(text, tokens, all_idx[:first] + all_idx[last + 1:])

    spans: List[Tuple[int, int]] = []
    for open_i, close_i in pairs:
        spans.append((open_i, close_i))
        if open_i > 0:
            spans.append((open_i - 1, close_i))

    verdicts = oracle.check_many([cut(a, b) for a, b in spans])
    winners = [span for span, ok in zip(spans, verdicts) if ok]
    if not winners:
        return text
    # Widest first; sorting is stable, so equal widths keep generation order.
    winners = sorted(winners, key=lambda span: span[0] - span[1])

    selected: List[Tuple[int, int]] = []
    for a, b in winners:
        overlaps = any(b >= sa and a <= sb for sa, sb in selected)
        if not overlaps:
            selected.append((a, b))

    dropped = {i for a, b in selected for i in range(a, b + 1)}
    merged = assemble_tokens(text, tokens, [i for i in all_idx if i not in dropped])
    if oracle.check(merged):
        return merged
    first, last = winners[0]
    return cut(first, last)
def parse_args() -> argparse.Namespace:
    """Build the reducer's command-line interface and parse ``sys.argv``.

    Options:
        --query: required path of the query file to minimise.
        --test: required path of the oracle shell script.
        --output: optional destination; when omitted the query file is
            overwritten.
        --jobs: parallel oracle runs; defaults to ``os.cpu_count()``, or 1
            when that is unknown.
        --timeout: seconds allowed per oracle run; defaults to 15.0.
    """
    parser = argparse.ArgumentParser(description='SQL test-case reducer')
    required_paths = (
        ('--query', 'path to the .sql file to minimise'),
        ('--test', 'path to the oracle shell script'),
    )
    for flag, text in required_paths:
        parser.add_argument(flag, required=True, help=text)
    parser.add_argument(
        '--output',
        help=('if given, write the reduced query here and leave --query '
              'untouched; if omitted, --query is modified in place'))
    cpus = os.cpu_count()
    parser.add_argument(
        '--jobs',
        type=int,
        default=1 if cpus is None else cpus,
        help='number of concurrent oracle processes (default: os.cpu_count())')
    parser.add_argument('--timeout', type=float, default=15.0,
                        help='per-oracle-call timeout in seconds')
    return parser.parse_args()
def run() -> int:
    """Reduce the query file named on the command line; return an exit status.

    Returns 2 without writing anything if the original query does not make the
    oracle script succeed. Otherwise it repeats line-, bracket- and token-level
    passes until a full pass leaves the text unchanged. The result is written
    to ``--output`` (or back over ``--query``), a summary is printed to stderr,
    and 0 is returned. The oracle's worker pool and temporary directory are
    cleaned up in every case.
    """
    args = parse_args()
    with open(args.query, 'r') as f:
        original_text = f.read()
    oracle = Oracle(args.test, jobs=args.jobs, timeout=args.timeout)
    try:
        if not oracle.check(original_text):
            sys.stderr.write(
                '[reducer] original query does not trigger the bug; aborting.\n')
            return 2
        start_tokens = token_count(original_text)
        # Monotonic clock: the reported duration must not jump if the wall
        # clock is adjusted while the (possibly long) reduction runs.
        start_time = time.monotonic()
        sys.stderr.write(f'[reducer] starting with {start_tokens} tokens '
                         f'(jobs={oracle.jobs})\n')
        text = original_text
        prev = None
        pass_idx = 0
        # Every reducer only returns text the oracle accepted (or its input),
        # so iterate to a fixpoint.
        while prev != text:
            prev = text
            pass_idx += 1
            text = reduce_lines(text, oracle)
            text = reduce_brackets(text, oracle)
            text = reduce_tokens(text, oracle)
            sys.stderr.write(
                f'[reducer] pass {pass_idx}: {token_count(text)} tokens\n')
        dest = args.output if args.output else args.query
        with open(dest, 'w') as f:
            f.write(text)
        elapsed = time.monotonic() - start_time
        final_tokens = token_count(text)
        pct = 100.0 * (1.0 - final_tokens / max(start_tokens, 1))
        sys.stderr.write(
            f'[reducer] done: {start_tokens} -> {final_tokens} tokens '
            f'({pct:.1f}% reduction) in {elapsed:.1f}s, '
            f'{oracle.calls} oracle calls (+{oracle.cache_hits} cached); '
            f'wrote {dest}\n')
        return 0
    finally:
        oracle.cleanup()
if __name__ == '__main__':
    # Use run()'s return value as the process exit status.
    raise SystemExit(run())