fix: save RFM ckpt metadata; reload model for trajectory generation.
Restores gcn_residual/graph_readout parity with training in update_best_artifacts; migration reads BEST_PRACTICE command when ckpt keys are missing. Made-with: Cursor
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
@@ -15,6 +16,8 @@ from train import ( # noqa: E402
|
||||
Batch,
|
||||
RFMModel,
|
||||
RigidTorsionState,
|
||||
build_torsion_subgraph_specs,
|
||||
clip_vector_norm_batch,
|
||||
generate_random_quaternion,
|
||||
generate_random_torsion_angles,
|
||||
generate_random_trans,
|
||||
@@ -29,6 +32,81 @@ CKPT = ROOT / "artifacts" / "best_model.pt"
|
||||
OUT_DIR = ROOT / "reports" / "trajectories"
|
||||
|
||||
|
||||
def _read_json_command(path: Path) -> str:
    """Return the ``"command"`` entry of the JSON file at *path* as a string.

    This is a best-effort reader: it returns ``""`` when the file is missing
    or unreadable, is not valid UTF-8 JSON, does not hold a JSON object at the
    top level, or has no (or a null) ``"command"`` entry.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The stored command converted with ``str``, or ``""`` if unavailable.
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return ""
    if not isinstance(payload, dict):
        # A list/number/string root has no "command" key to look up.
        return ""
    command = payload.get("command")
    # str(None) would yield the truthy text "None", which callers that test
    # for a non-empty command would mistake for a real command line.
    return "" if command is None else str(command)
|
||||
|
||||
|
||||
def _migration_command_string() -> str:
    """Return the command string used to recover flags absent from a checkpoint.

    ``BEST_PRACTICE.json`` (the anchored best run) is consulted first;
    ``reports/latest_eval.json`` is used only when that yields a blank command,
    since the latest evaluation may come from a different experiment.
    """
    anchored = _read_json_command(ROOT / "BEST_PRACTICE.json")
    if not anchored.strip():
        return _read_json_command(ROOT / "reports" / "latest_eval.json")
    return anchored
|
||||
|
||||
|
||||
def _infer_gcn_residual(ckpt: dict) -> bool:
    """Return the ``gcn_residual`` setting for *ckpt*.

    The value stored in the checkpoint wins; when the key is absent, the flag
    is considered set if ``--gcn-residual`` appears in the migration command.
    """
    if "gcn_residual" not in ckpt:
        return "--gcn-residual" in _migration_command_string()
    stored = ckpt["gcn_residual"]
    return bool(stored)
|
||||
|
||||
|
||||
def _infer_graph_readout(ckpt: dict) -> str:
    """Return the graph readout mode for *ckpt*.

    The value stored under ``"graph_readout"`` wins. When the key is absent,
    the migration command is searched for ``--graph-readout attention``;
    anything else falls back to ``"mean"``.

    Args:
        ckpt: Loaded checkpoint dict.

    Returns:
        The stored mode as a string, or ``"attention"`` / ``"mean"`` when it
        has to be recovered from the command line.
    """
    if "graph_readout" in ckpt:
        return str(ckpt["graph_readout"])
    # Accept any whitespace run as well as the "--flag=value" spelling, both
    # of which argparse treats the same as "--flag value".
    if re.search(r"--graph-readout(?:\s+|=)attention", _migration_command_string()):
        return "attention"
    return "mean"
|
||||
|
||||
|
||||
def _infer_torsion_head(ckpt: dict) -> str:
    """Return the torsion head mode for *ckpt*.

    The value stored under ``"torsion_head"`` wins. When the key is absent,
    the migration command is searched for ``--torsion-head bond_pair``;
    anything else falls back to ``"global"``.

    Args:
        ckpt: Loaded checkpoint dict.

    Returns:
        The stored mode as a string, or ``"bond_pair"`` / ``"global"`` when it
        has to be recovered from the command line.
    """
    if "torsion_head" in ckpt:
        return str(ckpt["torsion_head"])
    # Accept any whitespace run as well as the "--flag=value" spelling, both
    # of which argparse treats the same as "--flag value".
    if re.search(r"--torsion-head(?:\s+|=)bond_pair", _migration_command_string()):
        return "bond_pair"
    return "global"
|
||||
|
||||
|
||||
def _rfm_kwargs_from_ckpt(ckpt: dict, sample) -> dict:
    """Build the ``RFMModel`` keyword arguments described by a checkpoint.

    Settings are read from *ckpt*. For older checkpoints that lack a key, the
    flag is recovered from the migration command string (``BEST_PRACTICE.json``
    first, then ``reports/latest_eval.json``), and otherwise defaults apply:
    no channel layernorm, one head MLP layer, zero subgraph torsion clip and
    no explicit torsion sub-GCN layer count.

    Args:
        ckpt: Loaded checkpoint dict. ``"model_type"``, ``"num_nodes"``,
            ``"num_torsions"``, ``"hidden"`` and ``"num_layers"`` are required.
        sample: Sample handed to ``build_torsion_subgraph_specs`` when the
            torsion head mode is ``"bond_pair"``; unused otherwise.

    Returns:
        A dict that can be splatted into ``RFMModel(**kwargs)``.

    Raises:
        KeyError: If one of the required checkpoint keys is missing.
    """
    torsion_head = _infer_torsion_head(ckpt)
    # Subgraph specs are only built for the bond-pair head.
    torsion_subgraphs = (
        build_torsion_subgraph_specs(sample) if torsion_head == "bond_pair" else None
    )

    # A missing, null or negative layer count is reported as None.
    raw_sub_layers = ckpt.get("torsion_sub_gcn_layers")
    if raw_sub_layers is None or int(raw_sub_layers) < 0:
        torsion_sub_gcn_layers = None
    else:
        torsion_sub_gcn_layers = int(raw_sub_layers)

    channel_layernorm = ckpt.get("channel_layernorm")
    head_mlp_layers = ckpt.get("head_mlp_layers")
    if channel_layernorm is None or head_mlp_layers is None:
        # Only touch the command files when the checkpoint really lacks a key.
        cmd = _migration_command_string()
        if channel_layernorm is None:
            channel_layernorm = "--channel-layernorm" in cmd
        if head_mlp_layers is None:
            # Accept both "--head-mlp-layers N" and "--head-mlp-layers=N".
            match = re.search(r"--head-mlp-layers(?:\s+|=)(\d+)", cmd)
            head_mlp_layers = int(match.group(1)) if match else 1

    return {
        "model_type": ckpt["model_type"],
        "num_nodes": int(ckpt["num_nodes"]),
        "num_torsions": int(ckpt["num_torsions"]),
        "hidden": int(ckpt["hidden"]),
        "num_layers": int(ckpt["num_layers"]),
        "gcn_residual": _infer_gcn_residual(ckpt),
        "channel_layernorm": bool(channel_layernorm),
        "head_mlp_layers": int(head_mlp_layers),
        "graph_readout": _infer_graph_readout(ckpt),
        "torsion_head_mode": torsion_head,
        "torsion_subgraphs": torsion_subgraphs,
        # "or 0.0" keeps a null clip value from crashing float().
        "torsion_subgraph_output_clip": float(ckpt.get("subgraph_torsion_clip") or 0.0),
        "torsion_sub_gcn_layers": torsion_sub_gcn_layers,
    }
|
||||
|
||||
|
||||
def load_sample_from_ckpt(ckpt: dict):
|
||||
pkl = ckpt.get("sample_pickle_path", "")
|
||||
if pkl:
|
||||
@@ -39,7 +117,9 @@ def load_sample_from_ckpt(ckpt: dict):
|
||||
raise RuntimeError("Checkpoint does not contain sample source path.")
|
||||
|
||||
|
||||
def rollout_states(model, sample, init_state, device, n_frames: int):
|
||||
def rollout_states(
|
||||
model, sample, init_state, device, n_frames: int, omega_max_norm: float = 0.0
|
||||
):
|
||||
state = RigidTorsionState(
|
||||
center=init_state.center.to(device),
|
||||
rot=init_state.rot.to(device),
|
||||
@@ -68,7 +148,7 @@ def rollout_states(model, sample, init_state, device, n_frames: int):
|
||||
with torch.no_grad():
|
||||
v = model(batch, tb)
|
||||
vc = v.center.squeeze(0)
|
||||
vw = v.rot_omega.squeeze(0)
|
||||
vw = clip_vector_norm_batch(v.rot_omega, omega_max_norm).squeeze(0)
|
||||
vt = v.torsion.squeeze(0)
|
||||
dt = 1.0 / max(1, n_frames - 1)
|
||||
state = RigidTorsionState(
|
||||
@@ -82,20 +162,19 @@ def rollout_states(model, sample, init_state, device, n_frames: int):
|
||||
def main() -> int:
|
||||
if not CKPT.exists():
|
||||
raise FileNotFoundError(f"Missing checkpoint: {CKPT}")
|
||||
ckpt = torch.load(CKPT, map_location="cpu")
|
||||
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
|
||||
sample = prepare_sample(load_sample_from_ckpt(ckpt))
|
||||
|
||||
model = RFMModel(
|
||||
model_type=ckpt["model_type"],
|
||||
num_nodes=int(ckpt["num_nodes"]),
|
||||
num_torsions=int(ckpt["num_torsions"]),
|
||||
hidden=int(ckpt["hidden"]),
|
||||
num_layers=int(ckpt["num_layers"]),
|
||||
)
|
||||
model = RFMModel(**_rfm_kwargs_from_ckpt(ckpt, sample))
|
||||
model.load_state_dict(ckpt["state_dict"])
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device).eval()
|
||||
|
||||
omega_max_norm = ckpt.get("omega_max_norm")
|
||||
if omega_max_norm is None:
|
||||
m = re.search(r"--omega-max-norm\s+([0-9.+-eE]+)", _migration_command_string())
|
||||
omega_max_norm = float(m.group(1)) if m else 0.0
|
||||
omega_max_norm = float(omega_max_norm)
|
||||
n_traj = 6
|
||||
n_frames = int(ckpt.get("rollout_steps", 100))
|
||||
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -106,7 +185,9 @@ def main() -> int:
|
||||
rot=generate_rotation_matrix(generate_random_quaternion()),
|
||||
torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0]),
|
||||
)
|
||||
states = rollout_states(model, sample, init_state, device, n_frames=n_frames)
|
||||
states = rollout_states(
|
||||
model, sample, init_state, device, n_frames=n_frames, omega_max_norm=omega_max_norm
|
||||
)
|
||||
out_path = OUT_DIR / f"trajectory_{i:02d}.sdf"
|
||||
save_video_visualization(sample, states, str(out_path))
|
||||
files.append(out_path.name)
|
||||
|
||||
9
train.py
9
train.py
@@ -1457,6 +1457,15 @@ def main() -> int:
|
||||
"eval_runs": int(args.eval_runs),
|
||||
"rollout_steps": int(args.rollout_steps),
|
||||
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
||||
# RFMModel forward parity (required for trajectory / tooling reloads)
|
||||
"gcn_residual": bool(args.gcn_residual),
|
||||
"channel_layernorm": bool(args.channel_layernorm),
|
||||
"head_mlp_layers": int(args.head_mlp_layers),
|
||||
"graph_readout": str(args.graph_readout),
|
||||
"torsion_head": str(torsion_head_mode),
|
||||
"subgraph_torsion_clip": float(args.subgraph_torsion_clip),
|
||||
"torsion_sub_gcn_layers": int(args.torsion_sub_gcn_layers),
|
||||
"omega_max_norm": float(args.omega_max_norm),
|
||||
}
|
||||
torch.save(ckpt, ckpt_path)
|
||||
print(f"saved best checkpoint: {ckpt_path}")
|
||||
|
||||
Reference in New Issue
Block a user