git-filter-repo removed blobs; origin must be re-added. Pre-commit refreshes BEST_PRACTICE.json and trajectory manifest only (checkpoints stay local). Made-with: Cursor
190 lines
6.5 KiB
Python
190 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]

# JSON record of the best accepted evaluation (rewritten by the gate).
BEST = ROOT.joinpath("BEST_PRACTICE.json")

# Evaluation report the gate reads its metrics from.
LATEST = ROOT.joinpath("reports", "latest_eval.json")

# Checkpoint expected alongside the latest evaluation report.
LATEST_MODEL = ROOT.joinpath("artifacts", "latest_eval_best_model.pt")

# Destination the latest checkpoint is copied to when the gate passes.
BEST_MODEL = ROOT.joinpath("artifacts", "best_model.pt")
|
|
|
|
|
|
def git(*args: str) -> str:
    """Run ``git`` with *args* from the repository root and return its stdout.

    Leading and trailing whitespace is stripped from the captured output.

    Raises:
        subprocess.CalledProcessError: if git exits with a non-zero status.
    """
    command = ["git"]
    command.extend(args)
    captured = subprocess.check_output(command, cwd=ROOT, text=True)
    return captured.strip()
|
|
|
|
|
|
def staged_files() -> set[str]:
    """Return the paths listed by ``git diff --cached --name-only``."""
    listing = git("diff", "--cached", "--name-only")
    names = set(listing.splitlines())
    # An empty listing or a blank line must not show up as a path.
    names.discard("")
    return names
|
|
|
|
|
|
def changed_files_since(base_ref: str, head_ref: str = "HEAD") -> set[str]:
    """Return the paths reported by ``git diff --name-only base_ref..head_ref``.

    Args:
        base_ref: Revision on the left-hand side of the range.
        head_ref: Revision on the right-hand side of the range.
    """
    rev_range = f"{base_ref}..{head_ref}"
    listing = git("diff", "--name-only", rev_range)
    # filter(None, ...) drops empty lines, e.g. when nothing changed.
    return set(filter(None, listing.splitlines()))
|
|
|
|
|
|
def read_json(path: Path) -> dict:
    """Load and return the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's locale encoding (the other text reads in this file also
    use UTF-8).

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed document (callers index it like a JSON object).

    Raises:
        FileNotFoundError: if *path* does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    if not path.exists():
        raise FileNotFoundError(path)
    return json.loads(path.read_text(encoding="utf-8"))
|
|
|
|
|
|
def write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as 2-space-indented JSON plus a final newline.

    Args:
        path: File to create or overwrite.
        payload: JSON-serialisable mapping to store.
    """
    rendered = json.dumps(payload, indent=2)
    path.write_text(f"{rendered}\n")
|
|
|
|
|
|
def read_head_best() -> float:
    """Return the ``best_mean_rmsd_100`` recorded in ``BEST_PRACTICE.json`` at ``HEAD``.

    This is the ``HEAD`` special case of :func:`read_best_from_ref`; delegating
    keeps the lookup and its fallback rules in a single place instead of two
    copies that can drift apart.

    Returns:
        The recorded value as a float, or ``float("inf")`` when the file
        cannot be read from ``HEAD``, is not valid JSON, or lacks the key.
    """
    return read_best_from_ref("HEAD")
|
|
|
|
|
|
def fail(msg: str) -> int:
    """Print a failure message to stderr and return the exit status ``1``.

    The message is tagged with ``[pre-commit]`` unless it already begins with
    a hook tag. Callers in this file pass messages that are pre-tagged with
    ``[post-merge]`` or ``[pre-commit]``; prefixing those again produced
    misleading output such as ``[pre-commit] [post-merge] ...``.

    Args:
        msg: Human-readable reason for the failure.

    Returns:
        Always ``1``, so callers can write ``return fail(...)``.
    """
    hook_tags = ("[pre-commit]", "[post-merge]")
    # str.startswith accepts a tuple: true if any of the tags matches.
    line = msg if msg.startswith(hook_tags) else f"[pre-commit] {msg}"
    print(line, file=sys.stderr)
    return 1
|
|
|
|
|
|
def staged_train_py_text() -> str:
    """Return the contents of ``train.py`` as stored in the git index.

    The output of ``git show :train.py`` is decoded as UTF-8 explicitly, which
    matches how the working-tree copy of ``train.py`` is read elsewhere in
    this file and avoids depending on the platform's locale encoding. The
    text is returned unstripped.

    Raises:
        subprocess.CalledProcessError: if ``git show`` exits non-zero.
    """
    return subprocess.check_output(
        ["git", "show", ":train.py"],
        cwd=ROOT,
        encoding="utf-8",  # implies text mode, like the previous text=True
    )
|
|
|
|
|
|
def enforce_random_time_requirement(train_src: str) -> None:
    """Reject training source that contains a deny-listed token.

    The check is a plain substring search over *train_src*.

    Args:
        train_src: Full text of the training script.

    Raises:
        RuntimeError: naming the first deny-listed token (in list order) that
            occurs in *train_src*.
    """
    deny_list = (
        "t_zero_prob",
        "fixed_time",
        "torch.zeros((), device=cpu, dtype=torch.float32)",
    )
    offending = next((tok for tok in deny_list if tok in train_src), None)
    if offending is None:
        return
    raise RuntimeError(
        f"Flow-matching constraint violation: found forbidden token '{offending}'. "
        "Training time must remain random."
    )
|
|
|
|
|
|
def current_branch() -> str:
    """Return the output of ``git rev-parse --abbrev-ref HEAD``."""
    branch_name = git("rev-parse", "--abbrev-ref", "HEAD")
    return branch_name
|
|
|
|
|
|
def is_main_branch() -> bool:
    """Tell whether the currently checked-out branch is named ``main``."""
    branch = current_branch()
    return branch == "main"
|
|
|
|
|
|
def get_orig_head() -> str | None:
    """Return the output of ``git rev-parse --verify ORIG_HEAD``.

    Returns:
        The resolved revision, or ``None`` when git exits non-zero (for
        example because ``ORIG_HEAD`` does not exist).
    """
    try:
        resolved = git("rev-parse", "--verify", "ORIG_HEAD")
    except subprocess.CalledProcessError:
        resolved = None
    return resolved
|
|
|
|
|
|
def read_best_from_ref(ref: str) -> float:
    """Return the ``best_mean_rmsd_100`` stored in ``BEST_PRACTICE.json`` at *ref*.

    Any baseline that cannot be used yields ``float("inf")`` so that the
    caller's strict "less than" comparison treats it as "no previous best".
    That covers: the file being absent at *ref* (``git show`` fails), invalid
    JSON, a JSON document that is not an object, a missing key, and a value
    that cannot be converted to ``float`` (e.g. ``null`` or arbitrary text).

    Args:
        ref: Git revision to read the file from.

    Returns:
        The recorded best value, or ``float("inf")`` as described above.
    """
    no_baseline = float("inf")
    try:
        raw = git("show", f"{ref}:BEST_PRACTICE.json")
    except subprocess.CalledProcessError:
        return no_baseline
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return no_baseline
    # Valid JSON need not be an object; only a dict supports the key lookup.
    if not isinstance(parsed, dict):
        return no_baseline
    try:
        return float(parsed.get("best_mean_rmsd_100", no_baseline))
    except (TypeError, ValueError):
        # Present but not numeric (null, list, non-numeric string, ...).
        return no_baseline
|
|
|
|
|
|
def run_main_gate(previous_best: float, mode: str) -> int:
    """Check the latest evaluation against *previous_best* and record a new best.

    The evaluation report is read from ``LATEST``. The gate fails unless the
    report covers exactly 100 runs and its ``mean_rmsd_100`` is strictly lower
    than *previous_best*. On success ``BEST`` is rewritten, the latest
    checkpoint is copied over ``BEST_MODEL``, ``scripts/update_best_artifacts.py``
    is run with the current interpreter, and (in ``pre-commit`` mode only)
    ``BEST`` and ``reports/trajectories`` are staged with ``git add``.

    Args:
        previous_best: Baseline the new result must beat.
        mode: ``"post-merge"`` or ``"pre-commit"``; selects the message tag and
            whether the refreshed files are staged.

    Returns:
        ``0`` when the gate passes, ``1`` (via ``fail``) otherwise.
    """
    tag = "[pre-commit]"
    if mode == "post-merge":
        tag = "[post-merge]"

    report = read_json(LATEST)
    rmsd = float(report["mean_rmsd_100"])
    runs = int(report["num_runs"])

    if runs != 100:
        return fail(f"{tag} num_runs must be 100, got {runs}.")

    # Kept as a "<" test that is then negated: a NaN result compares False
    # and is therefore rejected as well.
    improved = rmsd < previous_best
    if not improved:
        return fail(
            f"{tag} No improvement: latest={rmsd:.6f}, previous_best={previous_best:.6f}. "
            "main integration with train.py requires strict improvement."
        )

    write_json(
        BEST,
        {
            "best_mean_rmsd_100": rmsd,
            "num_runs": runs,
            "timestamp_utc": report.get("timestamp_utc", ""),
            "command": report.get("command", ""),
            "notes": "Auto-updated by pre-commit from reports/latest_eval.json.",
            "updated_by_commit": "pending",
            "best_train_mse": report.get("best_train_mse"),
        },
    )
    shutil.copy2(LATEST_MODEL, BEST_MODEL)

    updater = ROOT / "scripts" / "update_best_artifacts.py"
    subprocess.check_call([sys.executable, str(updater)], cwd=ROOT)

    if mode == "pre-commit":
        # Stage the refreshed files so they become part of the commit.
        for target in (str(BEST), "reports/trajectories"):
            subprocess.check_call(["git", "add", target], cwd=ROOT)

    print(
        f"{tag} PASS: improved mean_rmsd_100 {previous_best:.6f} -> {rmsd:.6f}; "
        "BEST_PRACTICE.json refreshed; checkpoints (*.pt) and *.sdf stay local (gitignored, not staged).",
        file=sys.stderr,
    )
    return 0
|
|
|
|
|
|
def main() -> int:
    """Run the hook for the current stage and return the process exit status.

    The stage is read from the ``PRE_COMMIT_HOOK_STAGE`` environment variable
    and defaults to ``pre-commit``. ``post-merge`` runs the post-merge checks;
    every other value runs the pre-commit checks.
    """
    hook_stage = os.environ.get("PRE_COMMIT_HOOK_STAGE", "pre-commit")
    if hook_stage == "post-merge":
        return _post_merge_checks()
    return _pre_commit_checks()


def _post_merge_checks() -> int:
    """Validate a merge on ``main`` that changed ``train.py``; otherwise return 0."""
    if not is_main_branch():
        return 0
    orig_head = get_orig_head()
    if not orig_head:
        # Without ORIG_HEAD there is no pre-merge revision to diff against.
        return 0
    changed = changed_files_since(orig_head, "HEAD")
    if "train.py" not in changed:
        return 0

    train_src = (ROOT / "train.py").read_text(encoding="utf-8")
    try:
        enforce_random_time_requirement(train_src)
    except RuntimeError as e:
        return fail(f"[post-merge] {e}")

    if "reports/latest_eval.json" not in changed:
        return fail("[post-merge] train.py merged to main: reports/latest_eval.json must be included.")
    if "README.md" not in changed:
        return fail("[post-merge] train.py merged to main: README.md attempt log update is required.")
    if not LATEST_MODEL.exists():
        return fail("[post-merge] Missing artifacts/latest_eval_best_model.pt from merged training run.")

    # The baseline is whatever BEST_PRACTICE.json recorded before the merge.
    previous_best = read_best_from_ref(orig_head)
    return run_main_gate(previous_best=previous_best, mode="post-merge")


def _pre_commit_checks() -> int:
    """Validate a commit that stages ``train.py``; otherwise return 0."""
    staged = staged_files()
    if "train.py" not in staged:
        return 0

    try:
        train_src = staged_train_py_text()
    except subprocess.CalledProcessError:
        # Fall back to the working-tree copy when the index copy is unreadable.
        train_src = (ROOT / "train.py").read_text(encoding="utf-8")
    try:
        enforce_random_time_requirement(train_src)
    except RuntimeError as e:
        return fail(str(e))

    if not is_main_branch():
        print(
            "[pre-commit] train.py staged on a feature branch: "
            "skipping mean_rmsd_100 gate (runs on main only at commit/merge time).",
            file=sys.stderr,
        )
        return 0

    if "reports/latest_eval.json" not in staged:
        return fail("train.py changed: stage reports/latest_eval.json too.")
    if "README.md" not in staged:
        return fail("train.py changed: stage README.md with attempt log update.")
    if not LATEST_MODEL.exists():
        return fail("Missing artifacts/latest_eval_best_model.pt from training run.")

    previous_best = read_head_best()
    return run_main_gate(previous_best=previous_best, mode="pre-commit")
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|