fix: save RFM ckpt metadata; reload model for trajectory generation.
Restores gcn_residual/graph_readout parity with training in update_best_artifacts; migration reads BEST_PRACTICE command when ckpt keys are missing. Made-with: Cursor
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -15,6 +16,8 @@ from train import ( # noqa: E402
|
|||||||
Batch,
|
Batch,
|
||||||
RFMModel,
|
RFMModel,
|
||||||
RigidTorsionState,
|
RigidTorsionState,
|
||||||
|
build_torsion_subgraph_specs,
|
||||||
|
clip_vector_norm_batch,
|
||||||
generate_random_quaternion,
|
generate_random_quaternion,
|
||||||
generate_random_torsion_angles,
|
generate_random_torsion_angles,
|
||||||
generate_random_trans,
|
generate_random_trans,
|
||||||
@@ -29,6 +32,81 @@ CKPT = ROOT / "artifacts" / "best_model.pt"
|
|||||||
OUT_DIR = ROOT / "reports" / "trajectories"
|
OUT_DIR = ROOT / "reports" / "trajectories"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_json_command(path: Path) -> str:
    """Return the ``"command"`` entry of the JSON file at *path*, or ``""``.

    The lookup is best-effort. A missing or unreadable file, bytes that are
    not valid UTF-8, invalid JSON, a top-level value that is not a JSON
    object, and a missing or ``null`` ``"command"`` all yield the empty
    string, so callers can fall through to their own defaults.

    Args:
        path: Location of the JSON document to inspect.

    Returns:
        The command as a string. A list value (argv style) is joined with
        single spaces; any other non-string value is passed through ``str``.
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing/unreadable path; ValueError is the common
        # base of json.JSONDecodeError and UnicodeDecodeError.
        return ""
    if not isinstance(payload, dict):
        # e.g. a top-level list or scalar has no "command" to offer.
        return ""
    command = payload.get("command")
    if command is None:
        # Avoid str(None) == "None", which callers would treat as a command.
        return ""
    if isinstance(command, list):
        return " ".join(str(part) for part in command)
    return str(command)
|
||||||
|
|
||||||
|
|
||||||
|
def _migration_command_string() -> str:
    """Return the training command used to back-fill missing checkpoint keys.

    ``BEST_PRACTICE.json`` under ``ROOT`` is consulted first because it is
    anchored to the best run; ``reports/latest_eval.json`` is only used when
    the former yields no non-blank command, since it may describe a different
    experiment.
    """
    anchored = _read_json_command(ROOT / "BEST_PRACTICE.json")
    if not anchored.strip():
        latest_eval_path = ROOT / "reports" / "latest_eval.json"
        return _read_json_command(latest_eval_path)
    return anchored
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_gcn_residual(ckpt: dict) -> bool:
    """Return whether the checkpointed model was built with GCN residuals.

    Args:
        ckpt: Loaded checkpoint dictionary.

    Returns:
        ``bool(ckpt["gcn_residual"])`` when the key is stored. Otherwise
        ``True`` only if the migration command string contains the
        ``--gcn-residual`` option as a whole token.
    """
    if "gcn_residual" in ckpt:
        return bool(ckpt["gcn_residual"])
    # Require option boundaries so that longer option names which merely
    # contain the text (e.g. "--no-gcn-residual", "--gcn-residual-scale")
    # are not mistaken for the flag itself.
    flag = r"(?<![\w-])--gcn-residual(?![\w-])"
    return re.search(flag, _migration_command_string()) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_graph_readout(ckpt: dict) -> str:
    """Return the graph readout mode to rebuild the model with.

    Args:
        ckpt: Loaded checkpoint dictionary.

    Returns:
        ``str(ckpt["graph_readout"])`` when the key is stored. Otherwise
        ``"attention"`` if the migration command string selects it, else
        ``"mean"``.
    """
    if "graph_readout" in ckpt:
        return str(ckpt["graph_readout"])
    # Accept "--graph-readout attention", "--graph-readout=attention",
    # repeated whitespace and an optionally quoted value, while rejecting
    # longer option names or values that only start with the same text.
    selects_attention = r"(?<![\w-])--graph-readout(?:=|\s+)[\"']?attention(?![\w-])"
    if re.search(selects_attention, _migration_command_string()):
        return "attention"
    return "mean"
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_torsion_head(ckpt: dict) -> str:
    """Return the torsion head mode to rebuild the model with.

    Args:
        ckpt: Loaded checkpoint dictionary.

    Returns:
        ``str(ckpt["torsion_head"])`` when the key is stored. Otherwise
        ``"bond_pair"`` if the migration command string selects it, else
        ``"global"``.
    """
    if "torsion_head" in ckpt:
        return str(ckpt["torsion_head"])
    # Accept "--torsion-head bond_pair", "--torsion-head=bond_pair",
    # repeated whitespace and an optionally quoted value, while rejecting
    # longer option names or values that only start with the same text.
    selects_bond_pair = r"(?<![\w-])--torsion-head(?:=|\s+)[\"']?bond_pair(?![\w-])"
    if re.search(selects_bond_pair, _migration_command_string()):
        return "bond_pair"
    return "global"
|
||||||
|
|
||||||
|
|
||||||
|
def _rfm_kwargs_from_ckpt(ckpt: dict, sample) -> dict:
    """Build the ``RFMModel`` keyword arguments from checkpoint metadata.

    Keys that a checkpoint does not store are inferred from the command
    string returned by ``_migration_command_string`` (``BEST_PRACTICE.json``
    first, then ``reports/latest_eval.json``).

    Args:
        ckpt: Loaded checkpoint dictionary. ``model_type``, ``num_nodes``,
            ``num_torsions``, ``hidden`` and ``num_layers`` are required.
        sample: Prepared sample; only passed to
            ``build_torsion_subgraph_specs`` when the torsion head is
            ``"bond_pair"``.

    Returns:
        A dict suitable for ``RFMModel(**kwargs)``.

    Raises:
        KeyError: If one of the required checkpoint keys is missing.
    """
    torsion_head = _infer_torsion_head(ckpt)
    torsion_subgraphs = (
        build_torsion_subgraph_specs(sample) if torsion_head == "bond_pair" else None
    )

    # A missing, None or negative layer count is passed to the model as None.
    stored_sub_layers = ckpt.get("torsion_sub_gcn_layers")
    sub_layers = -1 if stored_sub_layers is None else int(stored_sub_layers)
    torsion_sub_gcn_layers = None if sub_layers < 0 else sub_layers

    channel_layernorm = ckpt.get("channel_layernorm")
    head_mlp_layers = ckpt.get("head_mlp_layers")
    if channel_layernorm is None or head_mlp_layers is None:
        # Only read the command files when the checkpoint lacks a key.
        cmd = _migration_command_string()
        if channel_layernorm is None:
            # Whole-token match so longer option names do not count.
            channel_layernorm = (
                re.search(r"(?<![\w-])--channel-layernorm(?![\w-])", cmd) is not None
            )
        if head_mlp_layers is None:
            # Accept both "--head-mlp-layers 2" and "--head-mlp-layers=2".
            found = re.search(r"(?<![\w-])--head-mlp-layers(?:=|\s+)(\d+)", cmd)
            head_mlp_layers = int(found.group(1)) if found else 1

    return {
        "model_type": ckpt["model_type"],
        "num_nodes": int(ckpt["num_nodes"]),
        "num_torsions": int(ckpt["num_torsions"]),
        "hidden": int(ckpt["hidden"]),
        "num_layers": int(ckpt["num_layers"]),
        "gcn_residual": _infer_gcn_residual(ckpt),
        "channel_layernorm": bool(channel_layernorm),
        "head_mlp_layers": int(head_mlp_layers),
        "graph_readout": _infer_graph_readout(ckpt),
        "torsion_head_mode": torsion_head,
        "torsion_subgraphs": torsion_subgraphs,
        # "or 0.0" also maps a stored None to the default clip value.
        "torsion_subgraph_output_clip": float(ckpt.get("subgraph_torsion_clip") or 0.0),
        "torsion_sub_gcn_layers": torsion_sub_gcn_layers,
    }
|
||||||
|
|
||||||
|
|
||||||
def load_sample_from_ckpt(ckpt: dict):
|
def load_sample_from_ckpt(ckpt: dict):
|
||||||
pkl = ckpt.get("sample_pickle_path", "")
|
pkl = ckpt.get("sample_pickle_path", "")
|
||||||
if pkl:
|
if pkl:
|
||||||
@@ -39,7 +117,9 @@ def load_sample_from_ckpt(ckpt: dict):
|
|||||||
raise RuntimeError("Checkpoint does not contain sample source path.")
|
raise RuntimeError("Checkpoint does not contain sample source path.")
|
||||||
|
|
||||||
|
|
||||||
def rollout_states(model, sample, init_state, device, n_frames: int):
|
def rollout_states(
|
||||||
|
model, sample, init_state, device, n_frames: int, omega_max_norm: float = 0.0
|
||||||
|
):
|
||||||
state = RigidTorsionState(
|
state = RigidTorsionState(
|
||||||
center=init_state.center.to(device),
|
center=init_state.center.to(device),
|
||||||
rot=init_state.rot.to(device),
|
rot=init_state.rot.to(device),
|
||||||
@@ -68,7 +148,7 @@ def rollout_states(model, sample, init_state, device, n_frames: int):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
v = model(batch, tb)
|
v = model(batch, tb)
|
||||||
vc = v.center.squeeze(0)
|
vc = v.center.squeeze(0)
|
||||||
vw = v.rot_omega.squeeze(0)
|
vw = clip_vector_norm_batch(v.rot_omega, omega_max_norm).squeeze(0)
|
||||||
vt = v.torsion.squeeze(0)
|
vt = v.torsion.squeeze(0)
|
||||||
dt = 1.0 / max(1, n_frames - 1)
|
dt = 1.0 / max(1, n_frames - 1)
|
||||||
state = RigidTorsionState(
|
state = RigidTorsionState(
|
||||||
@@ -82,20 +162,19 @@ def rollout_states(model, sample, init_state, device, n_frames: int):
|
|||||||
def main() -> int:
|
def main() -> int:
|
||||||
if not CKPT.exists():
|
if not CKPT.exists():
|
||||||
raise FileNotFoundError(f"Missing checkpoint: {CKPT}")
|
raise FileNotFoundError(f"Missing checkpoint: {CKPT}")
|
||||||
ckpt = torch.load(CKPT, map_location="cpu")
|
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
|
||||||
sample = prepare_sample(load_sample_from_ckpt(ckpt))
|
sample = prepare_sample(load_sample_from_ckpt(ckpt))
|
||||||
|
|
||||||
model = RFMModel(
|
model = RFMModel(**_rfm_kwargs_from_ckpt(ckpt, sample))
|
||||||
model_type=ckpt["model_type"],
|
|
||||||
num_nodes=int(ckpt["num_nodes"]),
|
|
||||||
num_torsions=int(ckpt["num_torsions"]),
|
|
||||||
hidden=int(ckpt["hidden"]),
|
|
||||||
num_layers=int(ckpt["num_layers"]),
|
|
||||||
)
|
|
||||||
model.load_state_dict(ckpt["state_dict"])
|
model.load_state_dict(ckpt["state_dict"])
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
model = model.to(device).eval()
|
model = model.to(device).eval()
|
||||||
|
|
||||||
|
omega_max_norm = ckpt.get("omega_max_norm")
|
||||||
|
if omega_max_norm is None:
|
||||||
|
m = re.search(r"--omega-max-norm\s+([0-9.+-eE]+)", _migration_command_string())
|
||||||
|
omega_max_norm = float(m.group(1)) if m else 0.0
|
||||||
|
omega_max_norm = float(omega_max_norm)
|
||||||
n_traj = 6
|
n_traj = 6
|
||||||
n_frames = int(ckpt.get("rollout_steps", 100))
|
n_frames = int(ckpt.get("rollout_steps", 100))
|
||||||
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -106,7 +185,9 @@ def main() -> int:
|
|||||||
rot=generate_rotation_matrix(generate_random_quaternion()),
|
rot=generate_rotation_matrix(generate_random_quaternion()),
|
||||||
torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0]),
|
torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0]),
|
||||||
)
|
)
|
||||||
states = rollout_states(model, sample, init_state, device, n_frames=n_frames)
|
states = rollout_states(
|
||||||
|
model, sample, init_state, device, n_frames=n_frames, omega_max_norm=omega_max_norm
|
||||||
|
)
|
||||||
out_path = OUT_DIR / f"trajectory_{i:02d}.sdf"
|
out_path = OUT_DIR / f"trajectory_{i:02d}.sdf"
|
||||||
save_video_visualization(sample, states, str(out_path))
|
save_video_visualization(sample, states, str(out_path))
|
||||||
files.append(out_path.name)
|
files.append(out_path.name)
|
||||||
|
|||||||
9
train.py
9
train.py
@@ -1457,6 +1457,15 @@ def main() -> int:
|
|||||||
"eval_runs": int(args.eval_runs),
|
"eval_runs": int(args.eval_runs),
|
||||||
"rollout_steps": int(args.rollout_steps),
|
"rollout_steps": int(args.rollout_steps),
|
||||||
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
||||||
|
# RFMModel forward parity (required for trajectory / tooling reloads)
|
||||||
|
"gcn_residual": bool(args.gcn_residual),
|
||||||
|
"channel_layernorm": bool(args.channel_layernorm),
|
||||||
|
"head_mlp_layers": int(args.head_mlp_layers),
|
||||||
|
"graph_readout": str(args.graph_readout),
|
||||||
|
"torsion_head": str(torsion_head_mode),
|
||||||
|
"subgraph_torsion_clip": float(args.subgraph_torsion_clip),
|
||||||
|
"torsion_sub_gcn_layers": int(args.torsion_sub_gcn_layers),
|
||||||
|
"omega_max_norm": float(args.omega_max_norm),
|
||||||
}
|
}
|
||||||
torch.save(ckpt, ckpt_path)
|
torch.save(ckpt, ckpt_path)
|
||||||
print(f"saved best checkpoint: {ckpt_path}")
|
print(f"saved best checkpoint: {ckpt_path}")
|
||||||
|
|||||||
Reference in New Issue
Block a user