#!/usr/bin/env python3
"""SQL test-case reducer.

Shrinks a SQL file while an external oracle script keeps exiting with
status 0 for it. Reduction alternates three passes (line-level delta
debugging, parenthesis-group removal, token-level delta debugging) until a
full round no longer changes the text.
"""
import argparse
from concurrent.futures import ThreadPoolExecutor
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import time
from typing import Callable, Dict, List, Sequence, Tuple

import sqlglot
from sqlglot.tokenizer_core import Token

# Encoding used for every text file the reducer reads or writes, so that the
# result does not depend on the locale of the host running the reducer.
_ENCODING = 'utf-8'


def tokenize(query: str) -> List[Token]:
    """Tokenize `query` with sqlglot's default tokenizer."""
    return list(sqlglot.Tokenizer().tokenize(query))


def token_count(query: str) -> int:
    """Returns the token count of a query using sqlglot's tokenizer.

    Falls back to counting whitespace-separated words when tokenizing fails.
    """
    try:
        return len(tokenize(query))
    except Exception:
        return len(query.split())


def _ttype(tok: Token) -> str:
    """Returns the token_type of a Token as a string."""
    token_type = tok.token_type
    if hasattr(token_type, 'name'):
        return token_type.name
    return str(token_type)


def assemble_tokens(source: str, tokens: List[Token],
                    kept: Sequence[int]) -> str:
    """Rebuilds a SQL string from a subset of tokens.

    Only the character spans of the dropped tokens are deleted from `source`;
    everything else (including the original whitespace) is kept verbatim.

    Args:
        source: the text that `tokens` was produced from.
        tokens: all tokens of `source`; `Token.start` / `Token.end` are used
            as character offsets, with `end` treated as inclusive.
        kept: indices into `tokens` of the tokens to keep.

    Returns:
        The rebuilt text, or '' when `kept` is empty.
    """
    keep = set(kept)
    if not keep:
        return ''
    # Half-open [start, end) spans of every token that is being dropped.
    drops = sorted((tok.start, tok.end + 1)
                   for i, tok in enumerate(tokens) if i not in keep)
    out: List[str] = []
    pos = 0
    for start, end in drops:
        out.append(source[pos:start])
        # If removing this span would glue two non-space characters together
        # (e.g. dropping an operator between two identifiers), put a single
        # space in its place so the neighbours do not merge into one token.
        left = source[start - 1] if start > 0 else ' '
        right = source[end] if end < len(source) else ' '
        if not left.isspace() and not right.isspace():
            out.append(' ')
        pos = end
    out.append(source[pos:])
    return ''.join(out)


class Oracle:
    """Runs the oracle script on candidate queries, with caching.

    A candidate counts as passing when the script exits with status 0. The
    path of the candidate file is handed to the script through the
    `TEST_CASE_LOCATION` environment variable.
    """

    def __init__(self, test_script: str, jobs: int, timeout: float) -> None:
        """Sets up the working directory, the cache and the worker pool.

        Args:
            test_script: path to the oracle shell script.
            jobs: number of worker threads; values below 1 are raised to 1.
            timeout: seconds allowed per oracle subprocess run.
        """
        # Absolute path to the oracle script
        # (e.g. `/reducer/queries/query1/script.sh`).
        self.test_script = os.path.abspath(test_script)
        # Directory of the oracle script (e.g. `/reducer/queries/query1`);
        # the script is executed with this as its working directory.
        self.script_dir = os.path.dirname(self.test_script)
        # Passed to `subprocess.run` as the per-call timeout, in seconds.
        self.timeout = timeout
        # Number of parallel worker threads used for oracle checks.
        self.jobs = max(1, jobs)
        # Fresh temporary directory that holds the candidate query files,
        # e.g. `/tmp/reducer-12345`.
        self.workdir = tempfile.mkdtemp(prefix='reducer-')
        # Memoized verdicts (query string -> bool) to avoid re-running the
        # script on a candidate that was already evaluated.
        self._cache: Dict[str, bool] = {}
        # Guards the cache and the statistics counters.
        self._cache_lock = threading.Lock()
        # Separate lock for handing out candidate file numbers, so parallel
        # workers never get the same filename.
        self._id_lock = threading.Lock()
        # Last candidate file number handed out (guarded by `_id_lock`).
        self._counter = 0
        # Number of oracle script executions (cache hits excluded).
        self.calls = 0
        # Number of verdicts served from the cache.
        self.cache_hits = 0
        # Thread pool when jobs > 1; None means checks run sequentially.
        self._pool = (ThreadPoolExecutor(max_workers=self.jobs)
                      if self.jobs > 1 else None)

    def _next_path(self) -> str:
        """Generates a unique filepath for the next candidate query."""
        with self._id_lock:
            self._counter += 1
            return os.path.join(self.workdir, f'cand-{self._counter}.sql')

    def _run_uncached(self, candidate: str) -> bool:
        """Runs the oracle script on `candidate` and returns its verdict.

        Called only when the candidate is not in the cache. The candidate is
        written to a fresh file which is removed again afterwards.

        Returns:
            True when the script exits with status 0; False on a non-zero
            exit status, on timeout, or when the script cannot be run.
        """
        path = self._next_path()
        try:
            with open(path, 'w', encoding=_ENCODING) as f:
                f.write(candidate)
            env = os.environ.copy()
            env['TEST_CASE_LOCATION'] = path
            try:
                result = subprocess.run(
                    ['bash', self.test_script],
                    cwd=self.script_dir,
                    env=env,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    timeout=self.timeout,
                )
            except Exception:
                # Deliberately best-effort: a timeout
                # (subprocess.TimeoutExpired) or any failure to run the
                # script is treated as "candidate does not pass".
                return False
            return result.returncode == 0
        finally:
            try:
                os.remove(path)
            except OSError:
                # The file may not exist if writing it failed; the whole
                # workdir is removed in cleanup() anyway.
                pass

    def check(self, candidate: str) -> bool:
        """Returns whether `candidate` passes the oracle.

        Empty or whitespace-only candidates are rejected without running the
        script. Otherwise the cached verdict is returned if there is one, and
        the script is run (and its verdict cached) if there is not.
        """
        if not candidate.strip():
            return False
        with self._cache_lock:
            if candidate in self._cache:
                self.cache_hits += 1
                return self._cache[candidate]
        # The lock is not held while the script runs, so other workers can
        # proceed in parallel.
        verdict = self._run_uncached(candidate)
        with self._cache_lock:
            self.calls += 1
            self._cache[candidate] = verdict
        return verdict

    def check_many(self, candidates: List[str]) -> List[bool]:
        """Evaluates several candidates, in parallel when `jobs > 1`.

        Returns the verdicts in the same order as `candidates`.
        """
        if self._pool is None or len(candidates) <= 1:
            return [self.check(c) for c in candidates]
        return list(self._pool.map(self.check, candidates))

    def cleanup(self) -> None:
        """Shuts down the worker pool (if any) and deletes the workdir."""
        if self._pool is not None:
            self._pool.shutdown(wait=True)
        shutil.rmtree(self.workdir, ignore_errors=True)


BuildFn = Callable[[List[int]], str]


def ddmin(n_items: int, build: BuildFn, oracle: Oracle) -> List[int]:
    """Delta-debugging minimization over the indices `0 .. n_items - 1`.

    Args:
        n_items: number of items the input is split into.
        build: turns a list of kept indices into a candidate string.
        oracle: decides whether a candidate still passes.

    Returns:
        The indices that remain once no tested chunk or complement passes
        any more.
    """
    kept: List[int] = list(range(n_items))
    n = 2
    while len(kept) >= 2:
        chunk = max(1, len(kept) // n)
        starts = range(0, len(kept), chunk)
        chunks = [kept[s:s + chunk] for s in starts]

        # Check if any individual chunk still reproduces the bug.
        subset_results = oracle.check_many([build(c) for c in chunks])
        hit = next((i for i, ok in enumerate(subset_results) if ok), None)
        if hit is not None:
            kept = chunks[hit]
            n = 2
            continue

        # Check if any complement (everything except one chunk) still
        # reproduces the bug. The chunks are contiguous slices of `kept`, so
        # each complement is `kept` with one slice cut out.
        complements = [kept[:s] + kept[s + chunk:] for s in starts]
        comp_results = oracle.check_many(
            [build(c) if c else '' for c in complements])
        hit = next(
            (i for i, ok in enumerate(comp_results) if ok and complements[i]),
            None)
        if hit is not None:
            kept = complements[hit]
            n = max(n - 1, 2)
            continue

        # Nothing passed at this granularity: stop if it cannot get finer,
        # otherwise split into twice as many chunks.
        if n >= len(kept):
            break
        n = min(n * 2, len(kept))
    return kept


def reduce_lines(text: str, oracle: Oracle) -> str:
    """Applies delta debugging at line-level granularity."""
    lines = text.split('\n')
    if len(lines) <= 1:
        return text

    def build(kept: List[int]) -> str:
        return '\n'.join(lines[i] for i in kept)

    return build(ddmin(len(lines), build, oracle))


def reduce_tokens(text: str, oracle: Oracle) -> str:
    """Applies delta debugging at token-level granularity.

    Returns `text` unchanged when it cannot be tokenized or has at most one
    token.
    """
    try:
        tokens = tokenize(text)
    except Exception:
        return text
    if len(tokens) <= 1:
        return text

    def build(kept: List[int]) -> str:
        return assemble_tokens(text, tokens, kept)

    return build(ddmin(len(tokens), build, oracle))


def _paren_pairs(tokens: List[Token]) -> List[Tuple[int, int]]:
    """Return (open_idx, close_idx) for every balanced parenthesis pair.

    Closing parentheses without a matching opener, and openers that are never
    closed, are ignored.
    """
    stack: List[int] = []
    pairs: List[Tuple[int, int]] = []
    for i, tok in enumerate(tokens):
        kind = _ttype(tok)
        if kind == 'L_PAREN':
            stack.append(i)
        elif kind == 'R_PAREN' and stack:
            pairs.append((stack.pop(), i))
    return pairs


def reduce_brackets(text: str, oracle: Oracle) -> str:
    """Removes balanced parenthesis groups that pass an oracle check.

    For every parenthesis pair two removals are tried: the group itself, and
    the group together with the token directly before it. In other words,
    this pass removes things like `func(args)` or `(x + y)` if they are
    redundant.

    Returns `text` unchanged when it cannot be tokenized, has no parenthesis
    pair, or no removal passes the oracle.
    """
    try:
        tokens = tokenize(text)
    except Exception:
        return text
    pairs = _paren_pairs(tokens)
    if not pairs:
        return text

    full = list(range(len(tokens)))
    # Inclusive token-index ranges that are candidates for removal.
    ranges: List[Tuple[int, int]] = []
    for lo, hi in pairs:
        ranges.append((lo, hi))
        if lo > 0:
            ranges.append((lo - 1, hi))

    def without(lo: int, hi: int) -> str:
        return assemble_tokens(text, tokens, full[:lo] + full[hi + 1:])

    results = oracle.check_many([without(lo, hi) for lo, hi in ranges])
    passing = [r for r, ok in zip(ranges, results) if ok]
    if not passing:
        return text

    # Greedily pick non-overlapping removals, largest range first.
    passing.sort(key=lambda r: r[1] - r[0], reverse=True)
    chosen: List[Tuple[int, int]] = []
    for lo, hi in passing:
        if all(hi < c[0] or lo > c[1] for c in chosen):
            chosen.append((lo, hi))

    removed = set()
    for lo, hi in chosen:
        removed.update(range(lo, hi + 1))
    combined = assemble_tokens(text, tokens,
                               [i for i in full if i not in removed])
    # Each removal passed on its own, but all of them together might not.
    if oracle.check(combined):
        return combined
    # Fall back to the single largest removal, which is known to pass.
    lo, hi = passing[0]
    return without(lo, hi)


def parse_args() -> argparse.Namespace:
    """Parses the command-line arguments of the reducer."""
    p = argparse.ArgumentParser(description='SQL test-case reducer')
    p.add_argument('--query',
                   required=True,
                   help='path to the .sql file to minimise')
    p.add_argument('--test',
                   required=True,
                   help='path to the oracle shell script')
    p.add_argument(
        '--output',
        help='if given, write the reduced query here and leave --query '
        'untouched; if omitted, --query is modified in place')
    p.add_argument(
        '--jobs',
        type=int,
        # os.cpu_count() may return None when the count is undetermined.
        default=os.cpu_count() or 1,
        help='number of concurrent oracle processes (default: os.cpu_count())')
    p.add_argument('--timeout',
                   type=float,
                   default=15.0,
                   help='per-oracle-call timeout in seconds')
    return p.parse_args()


def run() -> int:
    """Runs the reducer end to end.

    Returns:
        0 on success, 2 when the original query does not pass the oracle.
    """
    args = parse_args()
    with open(args.query, 'r', encoding=_ENCODING) as f:
        original_text = f.read()

    oracle = Oracle(args.test, jobs=args.jobs, timeout=args.timeout)
    try:
        if not oracle.check(original_text):
            sys.stderr.write(
                '[reducer] original query does not trigger the bug; aborting.\n')
            return 2

        start_tokens = token_count(original_text)
        start_time = time.time()
        sys.stderr.write(f'[reducer] starting with {start_tokens} tokens '
                         f'(jobs={oracle.jobs})\n')

        # Repeat all passes until a full round leaves the text unchanged.
        text = original_text
        prev = None
        pass_idx = 0
        while prev != text:
            prev = text
            pass_idx += 1
            text = reduce_lines(text, oracle)
            text = reduce_brackets(text, oracle)
            text = reduce_tokens(text, oracle)
            sys.stderr.write(
                f'[reducer] pass {pass_idx}: {token_count(text)} tokens\n')

        dest = args.output if args.output else args.query
        with open(dest, 'w', encoding=_ENCODING) as f:
            f.write(text)

        elapsed = time.time() - start_time
        final_tokens = token_count(text)
        pct = 100.0 * (1.0 - final_tokens / max(start_tokens, 1))
        sys.stderr.write(
            f'[reducer] done: {start_tokens} -> {final_tokens} tokens '
            f'({pct:.1f}% reduction) in {elapsed:.1f}s, '
            f'{oracle.calls} oracle calls (+{oracle.cache_hits} cached); '
            f'wrote {dest}\n')
        return 0
    finally:
        oracle.cleanup()


if __name__ == '__main__':
    sys.exit(run())