Files
ai-rfm/scripts/precommit_performance_gate.py
demian3b b8c440d654 chore: stop tracking *.pt/*.sdf; purge from history; align hooks and docs.
git-filter-repo removed blobs; origin must be re-added. Pre-commit refreshes
BEST_PRACTICE.json and trajectory manifest only (checkpoints stay local).

Made-with: Cursor
2026-04-17 14:01:06 +09:00

190 lines
6.5 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
# Repository root: this script lives one directory below it (scripts/).
ROOT = Path(__file__).resolve().parents[1]
# Committed record of the best mean_rmsd_100 accepted so far.
BEST = ROOT.joinpath("BEST_PRACTICE.json")
# Evaluation report written by the most recent training/eval run.
LATEST = ROOT.joinpath("reports", "latest_eval.json")
# Checkpoint from the most recent run; copied to BEST_MODEL when it improves.
LATEST_MODEL = ROOT.joinpath("artifacts", "latest_eval_best_model.pt")
# Local copy of the best checkpoint (*.pt files are kept out of git).
BEST_MODEL = ROOT.joinpath("artifacts", "best_model.pt")
def git(*args: str) -> str:
    """Run ``git <args>`` inside the repository root and return its stdout.

    Leading/trailing whitespace (including the final newline) is stripped.

    Raises:
        subprocess.CalledProcessError: if git exits with a non-zero status.
    """
    command = ["git"]
    command.extend(args)
    output = subprocess.check_output(command, cwd=ROOT, text=True)
    return output.strip()
def staged_files() -> set[str]:
    """Return the set of paths currently staged in the index.

    Paths come from ``git diff --cached --name-only``; blank lines are dropped.
    """
    listing = git("diff", "--cached", "--name-only")
    return set(filter(None, listing.splitlines()))
def changed_files_since(base_ref: str, head_ref: str = "HEAD") -> set[str]:
    """Return the paths that differ between ``base_ref`` and ``head_ref``.

    Uses ``git diff --name-only base_ref..head_ref``; blank lines are dropped.
    """
    rev_range = f"{base_ref}..{head_ref}"
    names = git("diff", "--name-only", rev_range).splitlines()
    return {name for name in names if name}
def read_json(path: Path) -> dict:
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the JSON interchange encoding) instead of
    the platform's locale default, so the result does not depend on the
    machine running the hook.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    # EAFP: read_text() itself raises FileNotFoundError (naming the path),
    # so a separate exists() probe is redundant and racy.
    return json.loads(path.read_text(encoding="utf-8"))
def write_json(path: Path, payload: dict) -> None:
    """Overwrite ``path`` with ``payload`` as 2-space-indented JSON plus a trailing newline."""
    body = "{}\n".format(json.dumps(payload, indent=2))
    path.write_text(body)
def read_head_best() -> float:
    """Return ``best_mean_rmsd_100`` from ``BEST_PRACTICE.json`` as committed at HEAD.

    Thin wrapper over ``read_best_from_ref("HEAD")``; the previous body was a
    line-for-line copy of that function. Keeping one implementation means
    any fix to the parsing/fallback rules applies to both the pre-commit and
    post-merge paths. Returns ``float("inf")`` when no baseline can be read.
    """
    return read_best_from_ref("HEAD")
def fail(msg: str) -> int:
    """Print ``msg`` to stderr with a hook-stage tag and return exit status 1.

    Untagged messages get the default ``[pre-commit]`` tag. Some call sites
    already build messages starting with ``[pre-commit]`` or
    ``[post-merge]``; those are printed unchanged. Previously they were
    rendered with a doubled tag such as ``[pre-commit] [post-merge] ...``.
    """
    if not msg.startswith(("[pre-commit]", "[post-merge]")):
        msg = f"[pre-commit] {msg}"
    print(msg, file=sys.stderr)
    return 1
def staged_train_py_text() -> str:
    """Return the contents of ``train.py`` as staged in the index (``git show :train.py``).

    The blob is decoded as UTF-8. This matches the working-tree fallback in
    ``main``, which reads ``train.py`` with ``encoding="utf-8"``. ``text=True``
    alone would use the locale encoding, so the staged and fallback reads
    could disagree or fail on non-UTF-8 systems.

    Raises:
        subprocess.CalledProcessError: if git cannot show the staged blob.
    """
    return subprocess.check_output(
        ["git", "show", ":train.py"], cwd=ROOT, encoding="utf-8"
    )
def enforce_random_time_requirement(train_src: str) -> None:
    """Reject training source containing tokens that pin the training time.

    ``train_src`` is scanned with plain substring search for each banned
    token, in the order listed below. The first token found is reported.

    Raises:
        RuntimeError: if any banned token occurs in ``train_src``.
    """
    banned_tokens = (
        "t_zero_prob",
        "fixed_time",
        "torch.zeros((), device=cpu, dtype=torch.float32)",
    )
    offending = next((tok for tok in banned_tokens if tok in train_src), None)
    if offending is None:
        return
    raise RuntimeError(
        f"Flow-matching constraint violation: found forbidden token '{offending}'. "
        "Training time must remain random."
    )
def current_branch() -> str:
    """Return the abbreviated name of HEAD as printed by ``git rev-parse --abbrev-ref HEAD``."""
    abbrev = git("rev-parse", "--abbrev-ref", "HEAD")
    return abbrev
def is_main_branch() -> bool:
    """Return True only when HEAD's abbreviated name is exactly ``main``."""
    branch = current_branch()
    return branch == "main"
def get_orig_head() -> str | None:
    """Return the commit id ORIG_HEAD resolves to, or None if it does not resolve.

    ``--quiet`` makes ``git rev-parse --verify`` exit non-zero silently
    instead of printing a "fatal: ..." line to the hook's stderr when
    ORIG_HEAD is absent. The caller treats that case as "nothing to check",
    so the message was pure noise. The None fallback is unchanged.
    """
    try:
        return git("rev-parse", "--verify", "--quiet", "ORIG_HEAD")
    except subprocess.CalledProcessError:
        return None
def read_best_from_ref(ref: str) -> float:
    """Return ``best_mean_rmsd_100`` from ``BEST_PRACTICE.json`` at git revision ``ref``.

    Returns ``float("inf")`` ("no usable baseline") whenever no usable value
    exists at ``ref``, so a caller comparing with ``<`` accepts any finite
    score. This covers:

    * git cannot show the file (unknown ref or file not committed);
    * the blob is not valid JSON;
    * the top-level JSON value is not an object;
    * the field is missing, ``null``, or not convertible to float.

    The last two cases previously crashed with AttributeError, TypeError or
    ValueError instead of falling back like the first two.
    """
    try:
        raw = git("show", f"{ref}:BEST_PRACTICE.json")
        parsed = json.loads(raw)
    except (subprocess.CalledProcessError, json.JSONDecodeError):
        return float("inf")
    value = parsed.get("best_mean_rmsd_100") if isinstance(parsed, dict) else None
    if value is None:
        return float("inf")
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("inf")
def run_main_gate(previous_best: float, mode: str) -> int:
    """Promote the latest evaluation if it strictly beats ``previous_best``.

    Reads ``reports/latest_eval.json`` and requires ``num_runs == 100`` and
    ``mean_rmsd_100 < previous_best``. On success it:

    * copies the latest checkpoint over ``artifacts/best_model.pt``;
    * rewrites ``BEST_PRACTICE.json``;
    * runs ``scripts/update_best_artifacts.py``.

    In ``"pre-commit"`` mode it also stages ``BEST_PRACTICE.json`` and
    ``reports/trajectories`` so they land in the commit being made. In any
    other mode the refreshed files are only left in the working tree.

    Args:
        previous_best: Baseline score to beat; ``inf`` means no baseline.
        mode: ``"pre-commit"`` or ``"post-merge"``. Selects the log prefix,
            the provenance note written to BEST_PRACTICE.json, and whether
            files are staged.

    Returns:
        0 when the latest run was promoted. 1 (via ``fail``) when the report
        is missing or malformed, ``num_runs`` is not 100, or there is no
        strict improvement.

    Raises:
        OSError: If copying the checkpoint fails.
        subprocess.CalledProcessError: If the artifact script or
            ``git add`` fails.
    """
    prefix = "[post-merge]" if mode == "post-merge" else "[pre-commit]"

    # A missing or malformed report is a gate failure, not a crash: report it
    # with the same prefix and exit status as the other rejections instead of
    # an uncaught traceback. (JSON/Unicode decode errors are ValueErrors.)
    try:
        latest = read_json(LATEST)
        latest_rmsd = float(latest["mean_rmsd_100"])
        latest_runs = int(latest["num_runs"])
    except FileNotFoundError:
        return fail(f"{prefix} Missing reports/latest_eval.json.")
    except (KeyError, TypeError, ValueError) as exc:
        return fail(f"{prefix} Malformed reports/latest_eval.json: {exc!r}")

    if latest_runs != 100:
        return fail(f"{prefix} num_runs must be 100, got {latest_runs}.")
    # Written as `not (a < b)` so a NaN score is rejected as "no improvement".
    if not (latest_rmsd < previous_best):
        return fail(
            f"{prefix} No improvement: latest={latest_rmsd:.6f}, previous_best={previous_best:.6f}. "
            "main integration with train.py requires strict improvement."
        )

    # Promote the checkpoint before rewriting BEST_PRACTICE.json, so a failed
    # copy cannot leave the JSON advertising a score whose model was never
    # promoted.
    shutil.copy2(LATEST_MODEL, BEST_MODEL)
    best_report = {
        "best_mean_rmsd_100": latest_rmsd,
        "num_runs": latest_runs,
        "timestamp_utc": latest.get("timestamp_utc", ""),
        "command": latest.get("command", ""),
        # Record which hook stage actually produced this update.
        "notes": f"Auto-updated by {mode} from reports/latest_eval.json.",
        "updated_by_commit": "pending",
        "best_train_mse": latest.get("best_train_mse"),
    }
    write_json(BEST, best_report)
    subprocess.check_call(
        [sys.executable, str(ROOT / "scripts" / "update_best_artifacts.py")], cwd=ROOT
    )
    if mode == "pre-commit":
        subprocess.check_call(["git", "add", str(BEST)], cwd=ROOT)
        subprocess.check_call(["git", "add", "reports/trajectories"], cwd=ROOT)
    print(
        f"{prefix} PASS: improved mean_rmsd_100 {previous_best:.6f} -> {latest_rmsd:.6f}; "
        "BEST_PRACTICE.json refreshed; checkpoints (*.pt) and *.sdf stay local (gitignored, not staged).",
        file=sys.stderr,
    )
    return 0
def _post_merge_gate() -> int:
    """Gate a merge into ``main`` that changed ``train.py``, diffing against ORIG_HEAD.

    NOTE(review): git ignores a post-merge hook's exit status, so a non-zero
    return here reports the problem but cannot undo the merge.
    """
    # Only merges that land on main and touch train.py are gated.
    if not is_main_branch():
        return 0
    orig_head = get_orig_head()
    if not orig_head:
        return 0
    changed = changed_files_since(orig_head, "HEAD")
    if "train.py" not in changed:
        return 0

    train_src = (ROOT / "train.py").read_text(encoding="utf-8")
    try:
        enforce_random_time_requirement(train_src)
    except RuntimeError as e:
        return fail(f"[post-merge] {e}")

    if "reports/latest_eval.json" not in changed:
        return fail(
            "[post-merge] train.py merged to main: reports/latest_eval.json must be included."
        )
    if "README.md" not in changed:
        return fail(
            "[post-merge] train.py merged to main: README.md attempt log update is required."
        )
    if not LATEST_MODEL.exists():
        return fail(
            "[post-merge] Missing artifacts/latest_eval_best_model.pt from merged training run."
        )

    # The baseline is the BEST_PRACTICE.json committed at ORIG_HEAD.
    previous_best = read_best_from_ref(orig_head)
    return run_main_gate(previous_best=previous_best, mode="post-merge")


def _pre_commit_gate() -> int:
    """Gate a commit that stages ``train.py``; the score gate applies on ``main`` only."""
    staged = staged_files()
    if "train.py" not in staged:
        return 0

    # Check the staged blob (what is actually being committed); fall back to
    # the working-tree file if git cannot show it.
    try:
        train_src = staged_train_py_text()
    except subprocess.CalledProcessError:
        train_src = (ROOT / "train.py").read_text(encoding="utf-8")
    try:
        enforce_random_time_requirement(train_src)
    except RuntimeError as e:
        return fail(str(e))

    # The random-time check runs on every branch; the score gate only on main.
    if not is_main_branch():
        print(
            "[pre-commit] train.py staged on a feature branch: "
            "skipping mean_rmsd_100 gate (runs on main only at commit/merge time).",
            file=sys.stderr,
        )
        return 0

    if "reports/latest_eval.json" not in staged:
        return fail("train.py changed: stage reports/latest_eval.json too.")
    if "README.md" not in staged:
        return fail("train.py changed: stage README.md with attempt log update.")
    if not LATEST_MODEL.exists():
        return fail("Missing artifacts/latest_eval_best_model.pt from training run.")

    previous_best = read_head_best()
    return run_main_gate(previous_best=previous_best, mode="pre-commit")


def main() -> int:
    """Hook entry point; returns the process exit status.

    The stage is read from the ``PRE_COMMIT_HOOK_STAGE`` environment variable
    (default ``"pre-commit"``). The value ``"post-merge"`` runs the
    post-merge gate; any other value runs the pre-commit gate.
    """
    hook_stage = os.environ.get("PRE_COMMIT_HOOK_STAGE", "pre-commit")
    if hook_stage == "post-merge":
        return _post_merge_gate()
    return _pre_commit_gate()
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())