372 lines
11 KiB
Python
372 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
import os
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import sqlglot
from sqlglot.tokenizer_core import Token
|
|
|
|
|
|
def tokenize(query: str) -> List[Token]:
    """Split *query* into sqlglot tokens.

    A fresh default ``sqlglot.Tokenizer`` is created for every call; any
    error raised by the tokenizer propagates to the caller.
    """
    tokenizer = sqlglot.Tokenizer()
    tokens = tokenizer.tokenize(query)
    return list(tokens)
|
|
|
|
|
|
def token_count(query: str) -> int:
    """Return how many sqlglot tokens *query* has.

    If tokenization fails for any reason, fall back to counting
    whitespace-separated words so progress reporting never crashes.
    """
    try:
        tokens = tokenize(query)
    except Exception:
        # Best-effort approximation for text sqlglot cannot tokenize.
        return len(query.split())
    return len(tokens)
|
|
|
|
|
|
def _ttype(tok: Token) -> str:
    """Return the name of *tok*'s token type (e.g. ``'L_PAREN'``).

    Enum-like types expose ``.name``; anything else is converted with ``str``.
    """
    kind = tok.token_type
    if hasattr(kind, 'name'):
        return kind.name
    return str(kind)
|
|
|
|
|
|
def assemble_tokens(source: str, tokens: List['Token'],
                    kept: Sequence[int]) -> str:
    """Rebuild a SQL string from a subset of tokens.

    Only the character spans of dropped tokens are deleted from *source*.
    All other text, including whitespace and anything not covered by a
    token, is copied verbatim.

    Runs of touching dropped tokens are deleted as one span. A single space
    is inserted only where that deletion would otherwise glue two
    non-whitespace characters together, so dropping ``+`` from ``a+b``
    yields ``a b``, while dropping ``(x)`` from ``f(x)`` yields ``f``.

    Args:
        source: The text *tokens* were produced from.
        tokens: Tokens of *source*; each must expose ``.start`` and ``.end``
            character offsets (``.end`` is treated as inclusive).
        kept: Indices into *tokens* to keep; order and duplicates are
            irrelevant.

    Returns:
        The reassembled text, or ``''`` when *kept* is empty.
    """
    keep = set(kept)
    if not keep:
        return ''

    # Half-open character spans [start, end) of every dropped token.
    spans = sorted((tok.start, tok.end + 1)
                   for i, tok in enumerate(tokens)
                   if i not in keep)

    # Coalesce spans that touch or overlap. After this, the characters just
    # outside each span are guaranteed to survive. The merge check below
    # therefore sees the real neighbours instead of characters that were
    # themselves deleted, which used to produce one spurious space per
    # dropped token.
    merged: List[List[int]] = []
    for start, end in spans:
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    out: List[str] = []
    pos = 0
    for start, end in merged:
        out.append(source[pos:start])
        left = source[start - 1] if start > 0 else ' '
        right = source[end] if end < len(source) else ' '
        # Deleting this span would fuse two tokens (e.g. `a` and `b` in
        # `a+b`), so keep them apart with one space.
        if not left.isspace() and not right.isspace():
            out.append(' ')
        pos = end
    out.append(source[pos:])
    return ''.join(out)
|
|
|
|
|
|
class Oracle:
    """Decides whether a candidate query still reproduces the bug.

    Each candidate is written to a fresh file inside a private temporary
    directory. The oracle script is then run as ``bash <script>`` from the
    script's own directory, with ``TEST_CASE_LOCATION`` set to that file.

    Exit status 0 within ``timeout`` seconds means "still reproduces".
    A non-zero exit, a timeout, or a failure to launch means it does not.
    Verdicts are memoized by exact query text. ``check_many`` fans out over a
    thread pool when ``jobs > 1``. Call ``cleanup()`` when done.
    """

    def __init__(self, test_script: str, jobs: int, timeout: float) -> None:
        # Absolute path to the oracle script (e.g. `/reducer/queries/query1/script.sh`).
        self.test_script = os.path.abspath(test_script)

        # Directory of the oracle script; used as the subprocess working directory.
        self.script_dir = os.path.dirname(self.test_script)

        # Seconds an oracle run may take before its whole process group is killed.
        self.timeout = timeout

        # Number of oracle runs allowed in parallel (at least 1).
        self.jobs = max(1, jobs)

        # Private temporary directory for candidate files, e.g. `/tmp/reducer-XXXX`.
        self.workdir = tempfile.mkdtemp(prefix='reducer-')

        # Memoized verdicts: query text -> reproduces the bug?
        self._cache: Dict[str, bool] = {}

        # Guards `_cache`, `calls` and `cache_hits`.
        self._cache_lock = threading.Lock()

        # Guards `_counter` so parallel workers never share a candidate filename.
        self._id_lock = threading.Lock()

        # Sequence number used to name candidate files.
        self._counter = 0

        # Guards `_live` and `_closed`.
        self._proc_lock = threading.Lock()

        # Oracle processes currently running, keyed by pid, so cleanup() can kill them.
        self._live: Dict[int, subprocess.Popen] = {}

        # Set by cleanup(); afterwards no new oracle process is started.
        self._closed = False

        # Number of oracle evaluations performed (cache misses).
        self.calls = 0

        # Number of verdicts served from the cache.
        self.cache_hits = 0

        # Worker pool for check_many(); None forces sequential execution.
        self._pool = (ThreadPoolExecutor(max_workers=self.jobs)
                      if self.jobs > 1 else None)

    def _next_path(self) -> str:
        """Return a fresh, unique path inside `workdir` for the next candidate."""
        with self._id_lock:
            self._counter += 1
            return os.path.join(self.workdir, f'cand-{self._counter}.sql')

    @staticmethod
    def _kill_tree(proc: subprocess.Popen) -> None:
        """SIGKILL *proc*'s whole process group, falling back to *proc* alone."""
        # The child is started with start_new_session=True, so its pid is
        # also its process-group id.
        if hasattr(os, 'killpg'):
            try:
                os.killpg(proc.pid, signal.SIGKILL)
                return
            except OSError:
                pass
        proc.kill()

    def _run_uncached(self, candidate: str) -> bool:
        """Run the oracle script on *candidate*, bypassing the cache.

        Returns True only if the script exits with status 0 within
        ``self.timeout`` seconds. On timeout, or if waiting is interrupted,
        the script and everything it spawned are killed. The candidate file
        is always removed afterwards.
        """
        path = self._next_path()
        try:
            with open(path, 'w') as f:
                f.write(candidate)
            env = os.environ.copy()
            env['TEST_CASE_LOCATION'] = path

            # Launch under the lock so cleanup() can never miss a process.
            with self._proc_lock:
                if self._closed:
                    return False
                try:
                    proc = subprocess.Popen(
                        ['bash', self.test_script],
                        cwd=self.script_dir,
                        env=env,
                        stdout=subprocess.DEVNULL,
                        stderr=subprocess.DEVNULL,
                        # Own session/process group, so a timeout can kill
                        # the script's children too instead of orphaning them.
                        start_new_session=True,
                    )
                except (OSError, ValueError, subprocess.SubprocessError):
                    return False
                self._live[proc.pid] = proc

            try:
                return proc.wait(timeout=self.timeout) == 0
            except subprocess.TimeoutExpired:
                return False
            finally:
                if proc.poll() is None:
                    # Timed out or interrupted: kill the whole tree, then reap.
                    self._kill_tree(proc)
                    proc.wait()
                with self._proc_lock:
                    self._live.pop(proc.pid, None)
        finally:
            try:
                os.remove(path)
            except OSError:
                pass

    def check(self, candidate: str) -> bool:
        """Return whether *candidate* still reproduces the bug.

        Blank candidates are rejected without running the oracle (and are
        not cached). Otherwise a memoized verdict is returned if present. On
        a miss the oracle runs outside the lock, so concurrent callers may
        occasionally evaluate the same text twice.
        """
        if not candidate.strip():
            return False
        with self._cache_lock:
            if candidate in self._cache:
                self.cache_hits += 1
                return self._cache[candidate]
        verdict = self._run_uncached(candidate)
        with self._cache_lock:
            self.calls += 1
            self._cache[candidate] = verdict
        return verdict

    def check_many(self, candidates: List[str]) -> List[bool]:
        """Evaluate several candidates, using the thread pool when `jobs > 1`.

        Verdicts are returned in the same order as *candidates*.
        """
        if self._pool is None or len(candidates) <= 1:
            return [self.check(c) for c in candidates]
        return list(self._pool.map(self.check, candidates))

    def cleanup(self) -> None:
        """Stop launching oracles, kill running ones, shut down the pool, delete `workdir`."""
        with self._proc_lock:
            self._closed = True
            running = list(self._live.values())
        # Oracle groups do not receive the terminal's SIGINT, so kill them
        # explicitly; otherwise shutdown() would wait up to `timeout` each.
        for proc in running:
            self._kill_tree(proc)
        if self._pool is not None:
            self._pool.shutdown(wait=True)
        shutil.rmtree(self.workdir, ignore_errors=True)
|
|
|
|
|
|
BuildFn = Callable[[List[int]], str]


def ddmin(n_items: int, build: BuildFn, oracle: Oracle) -> List[int]:
    """Delta-debugging minimization over the indices ``0 .. n_items - 1``.

    At each granularity the current index list is cut into pieces. The
    oracle is first asked about every piece on its own, then about every
    complement (everything except one piece). The first accepted candidate
    becomes the new current list. If none is accepted, the pieces are made
    finer until they are single indices.

    Args:
        n_items: Number of items; the search starts with all of them.
        build: Turns a list of indices into the text handed to the oracle.
        oracle: Judges candidate texts (``check_many`` is used).

    Returns:
        Indices in ascending order. This is either the full range, which is
        never re-checked here, or a subset whose built text the oracle
        accepted. When two or more indices remain, removing any single one
        was tried and rejected.
    """
    current: List[int] = list(range(n_items))
    granularity = 2

    while len(current) > 1:
        size = max(1, len(current) // granularity)
        pieces = [current[start:start + size]
                  for start in range(0, len(current), size)]

        # Phase 1: does any single piece still reproduce the bug?
        verdicts = oracle.check_many([build(piece) for piece in pieces])
        winner = next((idx for idx, ok in enumerate(verdicts) if ok), None)
        if winner is not None:
            current = pieces[winner]
            granularity = 2
            continue

        # Phase 2: does removing any single piece keep the bug?
        complements = [
            [item for other, piece in enumerate(pieces) if other != idx
             for item in piece]
            for idx in range(len(pieces))
        ]
        verdicts = oracle.check_many(
            [build(comp) if comp else '' for comp in complements])
        winner = next(
            (idx for idx, ok in enumerate(verdicts) if ok and complements[idx]),
            None)
        if winner is not None:
            current = complements[winner]
            granularity = max(granularity - 1, 2)
            continue

        # Phase 3: refine, or stop once pieces are already single items.
        if granularity >= len(current):
            break
        granularity = min(granularity * 2, len(current))

    return current
|
|
|
|
|
|
def reduce_lines(text: str, oracle: Oracle) -> str:
    """Delta-debug *text* treating each ``'\\n'``-separated line as one item.

    Single-line input is returned unchanged.
    """
    lines = text.split('\n')

    def build(kept: List[int]) -> str:
        return '\n'.join([lines[k] for k in kept])

    if len(lines) < 2:
        return text
    return build(ddmin(len(lines), build, oracle))
|
|
|
|
|
|
def reduce_tokens(text: str, oracle: Oracle) -> str:
    """Delta-debug *text* at sqlglot-token granularity.

    Dropped tokens are cut out of the original text with ``assemble_tokens``.
    Text that cannot be tokenized, or that has fewer than two tokens, is
    returned unchanged.
    """
    try:
        tokens = tokenize(text)
    except Exception:
        return text
    if len(tokens) < 2:
        return text

    def build(kept: List[int]) -> str:
        return assemble_tokens(text, tokens, kept)

    survivors = ddmin(len(tokens), build, oracle)
    return build(survivors)
|
|
|
|
|
|
def _paren_pairs(tokens: List[Token]) -> List[Tuple[int, int]]:
    """Return ``(open_idx, close_idx)`` for every matched parenthesis pair.

    Pairs appear in the order their closing parenthesis is seen (inner
    before outer). A ``)`` with no open ``(`` is ignored, and unclosed ``(``
    tokens produce no pair.
    """
    open_positions: List[int] = []
    matched: List[Tuple[int, int]] = []
    for idx, token in enumerate(tokens):
        kind = _ttype(token)
        if kind == 'L_PAREN':
            open_positions.append(idx)
            continue
        if kind != 'R_PAREN' or not open_positions:
            continue
        matched.append((open_positions.pop(), idx))
    return matched
|
|
|
|
|
|
def reduce_brackets(text: str, oracle: Oracle) -> str:
    """Try deleting whole parenthesized groups, e.g. ``func(args)`` or ``(x + y)``.

    For every matched ``( ... )`` pair, two removals are proposed: the group
    alone, and (when a token precedes it) the group together with that one
    preceding token. All proposals are checked in one batch.

    The accepted removals are then combined greedily, widest first and
    without overlap. If the oracle rejects the combination, the single
    widest accepted removal is applied instead. The text is returned
    unchanged when it cannot be tokenized, has no parenthesis pairs, or no
    removal is accepted.
    """
    try:
        tokens = tokenize(text)
    except Exception:
        return text

    pairs = _paren_pairs(tokens)
    if not pairs:
        return text

    all_idx = list(range(len(tokens)))

    def without(lo: int, hi: int) -> str:
        return assemble_tokens(text, tokens, all_idx[:lo] + all_idx[hi + 1:])

    spans: List[Tuple[int, int]] = []
    for open_i, close_i in pairs:
        spans.append((open_i, close_i))
        if open_i > 0:
            # Also try taking the preceding token (e.g. a function name).
            spans.append((open_i - 1, close_i))

    verdicts = oracle.check_many([without(*span) for span in spans])
    # Widest first; ties keep their original relative order.
    accepted = sorted((span for span, ok in zip(spans, verdicts) if ok),
                      key=lambda span: span[1] - span[0],
                      reverse=True)
    if not accepted:
        return text

    picked: List[Tuple[int, int]] = []
    for lo, hi in accepted:
        disjoint = all(hi < p_lo or lo > p_hi for p_lo, p_hi in picked)
        if disjoint:
            picked.append((lo, hi))

    dropped = {i for lo, hi in picked for i in range(lo, hi + 1)}
    merged = assemble_tokens(text, tokens, [i for i in all_idx if i not in dropped])
    if oracle.check(merged):
        return merged

    # Individually-accepted removals may conflict; fall back to the widest one.
    return without(*accepted[0])
|
|
|
|
|
|
def _positive_seconds(value: str) -> float:
    """argparse ``type=`` converter that accepts only a strictly positive float."""
    try:
        seconds = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f'invalid float value: {value!r}') from None
    # `not seconds > 0` also rejects NaN.
    if not seconds > 0:
        raise argparse.ArgumentTypeError(
            f'must be a positive number of seconds, got {value!r}')
    return seconds


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Arguments to parse; ``None`` (the default) means
            ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        Namespace with ``query``, ``test``, ``output`` (or None), ``jobs``
        and ``timeout``.

    Exits via ``SystemExit`` on invalid arguments. This includes a
    non-positive ``--timeout``, which would otherwise make every oracle call
    fail immediately.
    """
    p = argparse.ArgumentParser(description='SQL test-case reducer')
    p.add_argument('--query',
                   required=True,
                   help='path to the .sql file to minimise')
    p.add_argument('--test',
                   required=True,
                   help='path to the oracle shell script')
    p.add_argument(
        '--output',
        help='if given, write the reduced query here and leave --query '
        'untouched; if omitted, --query is modified in place')
    p.add_argument(
        '--jobs',
        type=int,
        default=os.cpu_count() or 1,
        help='number of concurrent oracle processes (default: os.cpu_count())')
    p.add_argument('--timeout',
                   type=_positive_seconds,
                   default=15.0,
                   help='per-oracle-call timeout in seconds')
    return p.parse_args(argv)
|
|
|
|
|
|
def _write_result(dest: str, text: str) -> None:
    """Overwrite *dest* with *text*."""
    with open(dest, 'w') as f:
        f.write(text)


def run() -> int:
    """Command-line entry point: reduce the ``--query`` file with the ``--test`` oracle.

    Passes (lines, parenthesized groups, tokens) are repeated until a full
    round makes no change. The destination (``--output``, or ``--query`` in
    place) is rewritten after every round that made progress, so an
    interrupted run keeps the best result found so far.

    Returns:
        Process exit status: 0 on success, 2 if the original query does not
        make the oracle succeed.
    """
    args = parse_args()

    with open(args.query, 'r') as f:
        original_text = f.read()

    dest = args.output if args.output else args.query

    # Start the clock before the first oracle call, so the reported time
    # covers every call counted in `oracle.calls`. monotonic() is immune to
    # wall-clock adjustments.
    start_time = time.monotonic()
    oracle = Oracle(args.test, jobs=args.jobs, timeout=args.timeout)
    try:
        if not oracle.check(original_text):
            sys.stderr.write(
                '[reducer] original query does not trigger the bug; aborting.\n')
            return 2

        start_tokens = token_count(original_text)
        sys.stderr.write(f'[reducer] starting with {start_tokens} tokens '
                         f'(jobs={oracle.jobs})\n')

        text = original_text
        prev = None
        pass_idx = 0
        # Fixpoint loop: stop once a whole round leaves the text unchanged.
        while prev != text:
            prev = text
            pass_idx += 1
            text = reduce_lines(text, oracle)
            text = reduce_brackets(text, oracle)
            text = reduce_tokens(text, oracle)
            sys.stderr.write(
                f'[reducer] pass {pass_idx}: {token_count(text)} tokens\n')
            if text != prev:
                # Checkpoint progress in case the run is interrupted later.
                _write_result(dest, text)

        # Always write at least once (e.g. when nothing could be removed).
        _write_result(dest, text)

        elapsed = time.monotonic() - start_time
        final_tokens = token_count(text)
        pct = 100.0 * (1.0 - final_tokens / max(start_tokens, 1))
        sys.stderr.write(
            f'[reducer] done: {start_tokens} -> {final_tokens} tokens '
            f'({pct:.1f}% reduction) in {elapsed:.1f}s, '
            f'{oracle.calls} oracle calls (+{oracle.cache_hits} cached); '
            f'wrote {dest}\n')
        return 0
    finally:
        oracle.cleanup()
|
|
|
|
|
|
if __name__ == '__main__':
    # Use run()'s return value as the process exit status.
    raise SystemExit(run())
|