fix: save RFM ckpt metadata; reload model for trajectory generation.

Restores gcn_residual/graph_readout parity with training in
update_best_artifacts; migration reads BEST_PRACTICE command when
ckpt keys are missing.

Made-with: Cursor
This commit is contained in:
demian3b
2026-04-17 13:38:42 +09:00
parent 16136c7df8
commit d894949ef9
2 changed files with 101 additions and 11 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import re
import sys
from copy import deepcopy
from pathlib import Path
@@ -15,6 +16,8 @@ from train import ( # noqa: E402
Batch,
RFMModel,
RigidTorsionState,
build_torsion_subgraph_specs,
clip_vector_norm_batch,
generate_random_quaternion,
generate_random_torsion_angles,
generate_random_trans,
@@ -29,6 +32,81 @@ CKPT = ROOT / "artifacts" / "best_model.pt"
OUT_DIR = ROOT / "reports" / "trajectories"
def _read_json_command(path: Path) -> str:
    """Return the ``"command"`` value stored in the JSON file at *path*.

    This is a best-effort reader: every failure mode yields ``""`` so that
    callers can fall back to another source instead of crashing.

    Args:
        path: JSON file expected to hold an object with a ``"command"`` entry.

    Returns:
        The command converted with ``str``, or ``""`` when the file is missing
        or unreadable, is not valid UTF-8 JSON, does not contain a JSON object
        at the top level, or has no (or a ``null``) ``"command"`` entry.
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file (no separate exists() check, which
        # would be race-prone). ValueError: covers json.JSONDecodeError and
        # UnicodeDecodeError.
        return ""
    if not isinstance(payload, dict):
        # A top-level list/number/string has no "command" to offer.
        return ""
    command = payload.get("command")
    # A JSON null must not turn into the truthy string "None".
    return "" if command is None else str(command)
def _migration_command_string() -> str:
    """Return the recorded training command used to fill in missing ckpt flags.

    ``BEST_PRACTICE.json`` (the anchored best run) is consulted first.
    ``reports/latest_eval.json`` is only a fallback for when that command is
    blank, because the latest eval may come from a different experiment.
    """
    anchored = _read_json_command(ROOT / "BEST_PRACTICE.json")
    if not anchored.strip():
        return _read_json_command(ROOT / "reports" / "latest_eval.json")
    return anchored
def _infer_gcn_residual(ckpt: dict) -> bool:
    """Tell whether the checkpointed model was built with GCN residuals.

    A ``"gcn_residual"`` entry in *ckpt* is authoritative (coerced to bool).
    Without it, the answer is whether the recorded training command contains
    the ``--gcn-residual`` flag.
    """
    if "gcn_residual" not in ckpt:
        recorded_command = _migration_command_string()
        return "--gcn-residual" in recorded_command
    return bool(ckpt["gcn_residual"])
def _infer_graph_readout(ckpt: dict) -> str:
    """Return the graph readout mode the checkpointed model was built with.

    A ``"graph_readout"`` entry in *ckpt* is authoritative (coerced to str).
    Without it, the recorded training command is searched for
    ``--graph-readout attention``; anything else yields ``"mean"``.
    """
    if "graph_readout" in ckpt:
        return str(ckpt["graph_readout"])
    # Accept "--graph-readout attention" with any run of whitespace as well as
    # the "--graph-readout=attention" spelling; an exact single-space
    # substring test would silently fall back to "mean" for those.
    if re.search(r"--graph-readout(?:=|\s+)attention", _migration_command_string()):
        return "attention"
    return "mean"
def _infer_torsion_head(ckpt: dict) -> str:
    """Return the torsion head mode the checkpointed model was built with.

    A ``"torsion_head"`` entry in *ckpt* is authoritative (coerced to str).
    Without it, the recorded training command is searched for
    ``--torsion-head bond_pair``; anything else yields ``"global"``.
    """
    if "torsion_head" in ckpt:
        return str(ckpt["torsion_head"])
    # Accept "--torsion-head bond_pair" with any run of whitespace as well as
    # the "--torsion-head=bond_pair" spelling; an exact single-space substring
    # test would silently fall back to "global" for those.
    if re.search(r"--torsion-head(?:=|\s+)bond_pair", _migration_command_string()):
        return "bond_pair"
    return "global"
def _rfm_kwargs_from_ckpt(ckpt: dict, sample) -> dict:
    """Build the ``RFMModel`` keyword arguments matching the train-time model.

    Values stored in *ckpt* always win. For older checkpoints that lack the
    metadata, flags are recovered from the recorded training command returned
    by ``_migration_command_string`` (``BEST_PRACTICE.json`` first, then
    ``reports/latest_eval.json``). Hard-coded defaults are used only when
    neither source has the flag.

    Args:
        ckpt: Loaded checkpoint dict. ``model_type``, ``num_nodes``,
            ``num_torsions``, ``hidden`` and ``num_layers`` are required.
        sample: Prepared sample; it is only passed to
            ``build_torsion_subgraph_specs`` when the torsion head is
            ``"bond_pair"``.

    Returns:
        A dict that can be splatted into ``RFMModel(**kwargs)``.

    Raises:
        KeyError: If one of the required checkpoint keys is missing.
    """
    torsion_head = _infer_torsion_head(ckpt)
    torsion_subgraphs = (
        build_torsion_subgraph_specs(sample) if torsion_head == "bond_pair" else None
    )

    # Only touch the JSON reports when at least one flag really is missing.
    legacy_keys = (
        "channel_layernorm",
        "head_mlp_layers",
        "torsion_sub_gcn_layers",
        "subgraph_torsion_clip",
    )
    needs_command = any(ckpt.get(key) is None for key in legacy_keys)
    cmd = _migration_command_string() if needs_command else ""

    channel_layernorm = ckpt.get("channel_layernorm")
    if channel_layernorm is None:
        channel_layernorm = "--channel-layernorm" in cmd

    head_mlp_layers = ckpt.get("head_mlp_layers")
    if head_mlp_layers is None:
        found = re.search(r"--head-mlp-layers(?:=|\s+)(\d+)", cmd)
        head_mlp_layers = int(found.group(1)) if found else 1

    # NOTE(review): the two flag spellings below are presumed to follow the
    # same dash convention as --head-mlp-layers; confirm against the training
    # script's argument parser. If they never match, the defaults apply.
    sub_gcn_layers = ckpt.get("torsion_sub_gcn_layers")
    if sub_gcn_layers is None:
        found = re.search(r"--torsion-sub-gcn-layers(?:=|\s+)(-?\d+)", cmd)
        sub_gcn_layers = int(found.group(1)) if found else -1
    sub_gcn_layers = int(sub_gcn_layers)

    torsion_clip = ckpt.get("subgraph_torsion_clip")
    if torsion_clip is None:
        found = re.search(
            r"--subgraph-torsion-clip(?:=|\s+)"
            r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)",
            cmd,
        )
        torsion_clip = float(found.group(1)) if found else 0.0

    return {
        "model_type": ckpt["model_type"],
        "num_nodes": int(ckpt["num_nodes"]),
        "num_torsions": int(ckpt["num_torsions"]),
        "hidden": int(ckpt["hidden"]),
        "num_layers": int(ckpt["num_layers"]),
        "gcn_residual": _infer_gcn_residual(ckpt),
        "channel_layernorm": bool(channel_layernorm),
        "head_mlp_layers": int(head_mlp_layers),
        "graph_readout": _infer_graph_readout(ckpt),
        "torsion_head_mode": torsion_head,
        "torsion_subgraphs": torsion_subgraphs,
        "torsion_subgraph_output_clip": float(torsion_clip),
        # A negative layer count is passed on as None.
        "torsion_sub_gcn_layers": None if sub_gcn_layers < 0 else sub_gcn_layers,
    }
def load_sample_from_ckpt(ckpt: dict):
pkl = ckpt.get("sample_pickle_path", "")
if pkl:
@@ -39,7 +117,9 @@ def load_sample_from_ckpt(ckpt: dict):
raise RuntimeError("Checkpoint does not contain sample source path.")
def rollout_states(model, sample, init_state, device, n_frames: int):
def rollout_states(
model, sample, init_state, device, n_frames: int, omega_max_norm: float = 0.0
):
state = RigidTorsionState(
center=init_state.center.to(device),
rot=init_state.rot.to(device),
@@ -68,7 +148,7 @@ def rollout_states(model, sample, init_state, device, n_frames: int):
with torch.no_grad():
v = model(batch, tb)
vc = v.center.squeeze(0)
vw = v.rot_omega.squeeze(0)
vw = clip_vector_norm_batch(v.rot_omega, omega_max_norm).squeeze(0)
vt = v.torsion.squeeze(0)
dt = 1.0 / max(1, n_frames - 1)
state = RigidTorsionState(
@@ -82,20 +162,19 @@ def rollout_states(model, sample, init_state, device, n_frames: int):
def main() -> int:
if not CKPT.exists():
raise FileNotFoundError(f"Missing checkpoint: {CKPT}")
ckpt = torch.load(CKPT, map_location="cpu")
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
sample = prepare_sample(load_sample_from_ckpt(ckpt))
model = RFMModel(
model_type=ckpt["model_type"],
num_nodes=int(ckpt["num_nodes"]),
num_torsions=int(ckpt["num_torsions"]),
hidden=int(ckpt["hidden"]),
num_layers=int(ckpt["num_layers"]),
)
model = RFMModel(**_rfm_kwargs_from_ckpt(ckpt, sample))
model.load_state_dict(ckpt["state_dict"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
omega_max_norm = ckpt.get("omega_max_norm")
if omega_max_norm is None:
m = re.search(r"--omega-max-norm\s+([0-9.+-eE]+)", _migration_command_string())
omega_max_norm = float(m.group(1)) if m else 0.0
omega_max_norm = float(omega_max_norm)
n_traj = 6
n_frames = int(ckpt.get("rollout_steps", 100))
OUT_DIR.mkdir(parents=True, exist_ok=True)
@@ -106,7 +185,9 @@ def main() -> int:
rot=generate_rotation_matrix(generate_random_quaternion()),
torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0]),
)
states = rollout_states(model, sample, init_state, device, n_frames=n_frames)
states = rollout_states(
model, sample, init_state, device, n_frames=n_frames, omega_max_norm=omega_max_norm
)
out_path = OUT_DIR / f"trajectory_{i:02d}.sdf"
save_video_visualization(sample, states, str(out_path))
files.append(out_path.name)

View File

@@ -1457,6 +1457,15 @@ def main() -> int:
"eval_runs": int(args.eval_runs),
"rollout_steps": int(args.rollout_steps),
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
# RFMModel forward parity (required for trajectory / tooling reloads)
"gcn_residual": bool(args.gcn_residual),
"channel_layernorm": bool(args.channel_layernorm),
"head_mlp_layers": int(args.head_mlp_layers),
"graph_readout": str(args.graph_readout),
"torsion_head": str(torsion_head_mode),
"subgraph_torsion_clip": float(args.subgraph_torsion_clip),
"torsion_sub_gcn_layers": int(args.torsion_sub_gcn_layers),
"omega_max_norm": float(args.omega_max_norm),
}
torch.save(ckpt, ckpt_path)
print(f"saved best checkpoint: {ckpt_path}")