train: EMA option, bond_pair subgraph head, eval metadata; refresh best checkpoint.

Measured mean_rmsd_100=2.350750 (100 runs) on sample.sdf with residual geodesic config;
improves main BEST_PRACTICE anchor 2.461592.

Made-with: Cursor
This commit is contained in:
demian3b
2026-04-17 10:33:23 +09:00
parent a1ce3ea0d0
commit 23c43400b1
2 changed files with 69 additions and 12 deletions

View File

@@ -1,12 +1,13 @@
{
"mean_rmsd_100": 2.620511382818222,
"mean_rmsd_100": 2.3507495498657227,
"num_runs": 100,
"timestamp_utc": "2026-04-16T15:40:21.752941+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1",
"timestamp_utc": "2026-04-17T01:29:09.183572+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --hidden 512 --gcn-layers 8 --epochs 280 --batch-size 24 --lr 7e-4 --seed 1 --eval-runs 100 --gcn-residual --rotation-loss geodesic --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --num-workers 4 --prefetch-factor 4 --out-dir /data/demian_dev/toy/ai_rfm/rfm_out_noema_geodesic_res280",
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
"best_train_mse": 3.1972899436950684,
"best_train_mse": 7.664132118225098,
"model_source": "best_train_checkpoint",
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
"stop_reason": "max_epochs",
"stop_epoch": 319
"stop_epoch": 279,
"ema_decay": 0.0
}

View File

@@ -518,6 +518,7 @@ class RFMModel(nn.Module):
torsion_head_mode: str = "global",
torsion_subgraphs: nn.ModuleList | None = None,
torsion_subgraph_output_clip: float = 0.0,
torsion_sub_gcn_layers: int | None = None,
):
super().__init__()
self.model_type = model_type
@@ -532,6 +533,7 @@ class RFMModel(nn.Module):
self.graph_readout = graph_readout
self.torsion_head_mode = torsion_head_mode
self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
self.sub_gcn_skip: nn.ModuleList | None = None
if model_type == "gcn":
self.convs = nn.ModuleList()
@@ -586,6 +588,7 @@ class RFMModel(nn.Module):
self.torsion_subgraphs = None
self.sub_convs = None
self.sub_pool_norm = None
self.sub_gcn_skip = None
elif torsion_head_mode == "bond_pair":
if model_type != "gcn":
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
@@ -594,10 +597,14 @@ class RFMModel(nn.Module):
# Dedicated subgraph encoder (weights not shared with full-graph convs).
# First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
# throwing away trunk geometry context when re-encoding the movable fragment only.
ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers)
ns = max(1, min(ns, num_layers))
self.sub_num_layers = ns
self.sub_convs = nn.ModuleList()
self.sub_convs.append(GCNConv(3 + hidden, hidden))
for _ in range(num_layers - 1):
for _ in range(ns - 1):
self.sub_convs.append(GCNConv(hidden, hidden))
self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)])
# Global pooled context + mean pool over subgraph GCN.
pair_in = d + hidden
ph = max(64, min(256, hidden * 2))
@@ -653,10 +660,11 @@ class RFMModel(nn.Module):
h = batch_sub.x
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
batch_vec = batch_sub.batch
assert self.sub_gcn_skip is not None
for idx, conv in enumerate(self.sub_convs):
h_new = self.act(conv(h, edge_index))
if self.gcn_residual and h_new.shape == h.shape:
h = h_new + self.gcn_skip[idx](h)
h = h_new + self.sub_gcn_skip[idx](h)
else:
h = h_new
sub_pooled = global_mean_pool(h, batch_vec)
@@ -1044,6 +1052,12 @@ def main() -> int:
default=6.0,
help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
)
parser.add_argument(
"--torsion-sub-gcn-layers",
type=int,
default=-1,
help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).",
)
parser.add_argument(
"--early-stop-patience",
type=int,
@@ -1074,6 +1088,15 @@ def main() -> int:
default=0.0,
help="If >0, clip predicted rotation omega vector norm for stability.",
)
parser.add_argument(
"--ema-decay",
type=float,
default=0.0,
help=(
"If in (0,1), maintain an EMA shadow of model weights after each optimizer step; "
"best_train checkpoint and final eval load the EMA weights (training uses online weights)."
),
)
args = parser.parse_args()
torch.manual_seed(args.seed)
@@ -1119,6 +1142,8 @@ def main() -> int:
if torsion_head_mode == "bond_pair":
torsion_subgraphs = build_torsion_subgraph_specs(sample)
torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
model = RFMModel(
model_type=args.model_type,
num_nodes=int(sample.num_nodes),
@@ -1132,13 +1157,14 @@ def main() -> int:
torsion_head_mode=torsion_head_mode,
torsion_subgraphs=torsion_subgraphs,
torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
torsion_sub_gcn_layers=torsion_sub_gcn_layers,
).to(device)
subgraph_param_ids: set[int] = set()
if torsion_head_mode == "bond_pair":
for name, p in model.named_parameters():
if not p.requires_grad:
continue
if name.startswith(("sub_convs", "torsion_pair", "sub_pool_norm")):
if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")):
subgraph_param_ids.add(id(p))
if args.compile and hasattr(torch, "compile"):
model = torch.compile(model) # type: ignore[assignment]
@@ -1161,6 +1187,25 @@ def main() -> int:
else:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
def model_core(m: nn.Module) -> nn.Module:
return m._orig_mod if hasattr(m, "_orig_mod") else m
ema_decay = float(args.ema_decay)
ema_state: dict[str, torch.Tensor] | None = None
if 0.0 < ema_decay < 1.0:
core0 = model_core(model)
ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()}
@torch.no_grad()
def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None:
src = model_core(m).state_dict()
for k, dst in shadow.items():
t = src[k]
if torch.is_floating_point(t):
dst.mul_(decay).add_(t, alpha=1.0 - decay)
else:
dst.copy_(t)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(
f"device={device} model={args.model_type} hidden={args.hidden} "
@@ -1180,9 +1225,11 @@ def main() -> int:
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
)
if torsion_head_mode == "bond_pair":
sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
print(
f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
f"subgraph_torsion_clip={args.subgraph_torsion_clip}"
f"subgraph_torsion_clip={args.subgraph_torsion_clip} "
f"torsion_sub_gcn_layers={sub_l}"
)
if args.rotation_loss == "hybrid":
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
@@ -1198,6 +1245,8 @@ def main() -> int:
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
)
print(f"omega_max_norm={args.omega_max_norm}")
if ema_state is not None:
print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)")
nw = args.num_workers
if nw < 0:
@@ -1336,12 +1385,17 @@ def main() -> int:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
if ema_state is not None:
ema_update_(ema_state, model, ema_decay)
avg_mse = sum_mse / accum
avg_mse_val = float(avg_mse.detach().cpu())
if avg_mse_val < best_train_mse:
best_train_mse = avg_mse_val
core = model._orig_mod if hasattr(model, "_orig_mod") else model
core = model_core(model)
if ema_state is not None:
best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
else:
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
should_check_early_stop = (
@@ -1384,7 +1438,7 @@ def main() -> int:
ckpt_path = ""
if best_state is not None:
core = model._orig_mod if hasattr(model, "_orig_mod") else model
core = model_core(model)
core.load_state_dict(best_state)
print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
artifacts_dir = os.path.join(here, "artifacts")
@@ -1402,6 +1456,7 @@ def main() -> int:
"sample_pickle_pos": int(args.pickle_pos),
"eval_runs": int(args.eval_runs),
"rollout_steps": int(args.rollout_steps),
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
}
torch.save(ckpt, ckpt_path)
print(f"saved best checkpoint: {ckpt_path}")
@@ -1435,6 +1490,7 @@ def main() -> int:
"checkpoint_path": ckpt_path,
"stop_reason": stop_reason,
"stop_epoch": int(stop_epoch),
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
}
with open(latest_path, "w", encoding="utf-8") as f:
json.dump(latest_report, f, indent=2)