train: EMA option, bond_pair subgraph head, eval metadata; refresh best checkpoint.

Measured mean_rmsd_100=2.350750 (100 runs) on sample.sdf with residual geodesic config;
improves main BEST_PRACTICE anchor 2.461592.

Made-with: Cursor
This commit is contained in:
demian3b
2026-04-17 10:33:23 +09:00
parent a1ce3ea0d0
commit 23c43400b1
2 changed files with 69 additions and 12 deletions

View File

@@ -1,12 +1,13 @@
{ {
"mean_rmsd_100": 2.620511382818222, "mean_rmsd_100": 2.3507495498657227,
"num_runs": 100, "num_runs": 100,
"timestamp_utc": "2026-04-16T15:40:21.752941+00:00", "timestamp_utc": "2026-04-17T01:29:09.183572+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1", "command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --hidden 512 --gcn-layers 8 --epochs 280 --batch-size 24 --lr 7e-4 --seed 1 --eval-runs 100 --gcn-residual --rotation-loss geodesic --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --num-workers 4 --prefetch-factor 4 --out-dir /data/demian_dev/toy/ai_rfm/rfm_out_noema_geodesic_res280",
"notes": "Final test metric from 100 random-initialized rollouts to time=1.", "notes": "Final test metric from 100 random-initialized rollouts to time=1.",
"best_train_mse": 3.1972899436950684, "best_train_mse": 7.664132118225098,
"model_source": "best_train_checkpoint", "model_source": "best_train_checkpoint",
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt", "checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
"stop_reason": "max_epochs", "stop_reason": "max_epochs",
"stop_epoch": 319 "stop_epoch": 279,
"ema_decay": 0.0
} }

View File

@@ -518,6 +518,7 @@ class RFMModel(nn.Module):
torsion_head_mode: str = "global", torsion_head_mode: str = "global",
torsion_subgraphs: nn.ModuleList | None = None, torsion_subgraphs: nn.ModuleList | None = None,
torsion_subgraph_output_clip: float = 0.0, torsion_subgraph_output_clip: float = 0.0,
torsion_sub_gcn_layers: int | None = None,
): ):
super().__init__() super().__init__()
self.model_type = model_type self.model_type = model_type
@@ -532,6 +533,7 @@ class RFMModel(nn.Module):
self.graph_readout = graph_readout self.graph_readout = graph_readout
self.torsion_head_mode = torsion_head_mode self.torsion_head_mode = torsion_head_mode
self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip)) self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
self.sub_gcn_skip: nn.ModuleList | None = None
if model_type == "gcn": if model_type == "gcn":
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
@@ -586,6 +588,7 @@ class RFMModel(nn.Module):
self.torsion_subgraphs = None self.torsion_subgraphs = None
self.sub_convs = None self.sub_convs = None
self.sub_pool_norm = None self.sub_pool_norm = None
self.sub_gcn_skip = None
elif torsion_head_mode == "bond_pair": elif torsion_head_mode == "bond_pair":
if model_type != "gcn": if model_type != "gcn":
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn") raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
@@ -594,10 +597,14 @@ class RFMModel(nn.Module):
# Dedicated subgraph encoder (weights not shared with full-graph convs). # Dedicated subgraph encoder (weights not shared with full-graph convs).
# First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids # First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
# throwing away trunk geometry context when re-encoding the movable fragment only. # throwing away trunk geometry context when re-encoding the movable fragment only.
ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers)
ns = max(1, min(ns, num_layers))
self.sub_num_layers = ns
self.sub_convs = nn.ModuleList() self.sub_convs = nn.ModuleList()
self.sub_convs.append(GCNConv(3 + hidden, hidden)) self.sub_convs.append(GCNConv(3 + hidden, hidden))
for _ in range(num_layers - 1): for _ in range(ns - 1):
self.sub_convs.append(GCNConv(hidden, hidden)) self.sub_convs.append(GCNConv(hidden, hidden))
self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)])
# Global pooled context + mean pool over subgraph GCN. # Global pooled context + mean pool over subgraph GCN.
pair_in = d + hidden pair_in = d + hidden
ph = max(64, min(256, hidden * 2)) ph = max(64, min(256, hidden * 2))
@@ -653,10 +660,11 @@ class RFMModel(nn.Module):
h = batch_sub.x h = batch_sub.x
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes) edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
batch_vec = batch_sub.batch batch_vec = batch_sub.batch
assert self.sub_gcn_skip is not None
for idx, conv in enumerate(self.sub_convs): for idx, conv in enumerate(self.sub_convs):
h_new = self.act(conv(h, edge_index)) h_new = self.act(conv(h, edge_index))
if self.gcn_residual and h_new.shape == h.shape: if self.gcn_residual and h_new.shape == h.shape:
h = h_new + self.gcn_skip[idx](h) h = h_new + self.sub_gcn_skip[idx](h)
else: else:
h = h_new h = h_new
sub_pooled = global_mean_pool(h, batch_vec) sub_pooled = global_mean_pool(h, batch_vec)
@@ -1044,6 +1052,12 @@ def main() -> int:
default=6.0, default=6.0,
help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.", help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
) )
parser.add_argument(
"--torsion-sub-gcn-layers",
type=int,
default=-1,
help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).",
)
parser.add_argument( parser.add_argument(
"--early-stop-patience", "--early-stop-patience",
type=int, type=int,
@@ -1074,6 +1088,15 @@ def main() -> int:
default=0.0, default=0.0,
help="If >0, clip predicted rotation omega vector norm for stability.", help="If >0, clip predicted rotation omega vector norm for stability.",
) )
parser.add_argument(
"--ema-decay",
type=float,
default=0.0,
help=(
"If in (0,1), maintain an EMA shadow of model weights after each optimizer step; "
"best_train checkpoint and final eval load the EMA weights (training uses online weights)."
),
)
args = parser.parse_args() args = parser.parse_args()
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
@@ -1119,6 +1142,8 @@ def main() -> int:
if torsion_head_mode == "bond_pair": if torsion_head_mode == "bond_pair":
torsion_subgraphs = build_torsion_subgraph_specs(sample) torsion_subgraphs = build_torsion_subgraph_specs(sample)
torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
model = RFMModel( model = RFMModel(
model_type=args.model_type, model_type=args.model_type,
num_nodes=int(sample.num_nodes), num_nodes=int(sample.num_nodes),
@@ -1132,13 +1157,14 @@ def main() -> int:
torsion_head_mode=torsion_head_mode, torsion_head_mode=torsion_head_mode,
torsion_subgraphs=torsion_subgraphs, torsion_subgraphs=torsion_subgraphs,
torsion_subgraph_output_clip=float(args.subgraph_torsion_clip), torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
torsion_sub_gcn_layers=torsion_sub_gcn_layers,
).to(device) ).to(device)
subgraph_param_ids: set[int] = set() subgraph_param_ids: set[int] = set()
if torsion_head_mode == "bond_pair": if torsion_head_mode == "bond_pair":
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if not p.requires_grad: if not p.requires_grad:
continue continue
if name.startswith(("sub_convs", "torsion_pair", "sub_pool_norm")): if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")):
subgraph_param_ids.add(id(p)) subgraph_param_ids.add(id(p))
if args.compile and hasattr(torch, "compile"): if args.compile and hasattr(torch, "compile"):
model = torch.compile(model) # type: ignore[assignment] model = torch.compile(model) # type: ignore[assignment]
@@ -1161,6 +1187,25 @@ def main() -> int:
else: else:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
def model_core(m: nn.Module) -> nn.Module:
return m._orig_mod if hasattr(m, "_orig_mod") else m
ema_decay = float(args.ema_decay)
ema_state: dict[str, torch.Tensor] | None = None
if 0.0 < ema_decay < 1.0:
core0 = model_core(model)
ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()}
@torch.no_grad()
def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None:
src = model_core(m).state_dict()
for k, dst in shadow.items():
t = src[k]
if torch.is_floating_point(t):
dst.mul_(decay).add_(t, alpha=1.0 - decay)
else:
dst.copy_(t)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print( print(
f"device={device} model={args.model_type} hidden={args.hidden} " f"device={device} model={args.model_type} hidden={args.hidden} "
@@ -1180,9 +1225,11 @@ def main() -> int:
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}" f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
) )
if torsion_head_mode == "bond_pair": if torsion_head_mode == "bond_pair":
sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
print( print(
f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} " f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
f"subgraph_torsion_clip={args.subgraph_torsion_clip}" f"subgraph_torsion_clip={args.subgraph_torsion_clip} "
f"torsion_sub_gcn_layers={sub_l}"
) )
if args.rotation_loss == "hybrid": if args.rotation_loss == "hybrid":
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}") print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
@@ -1198,6 +1245,8 @@ def main() -> int:
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}" f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
) )
print(f"omega_max_norm={args.omega_max_norm}") print(f"omega_max_norm={args.omega_max_norm}")
if ema_state is not None:
print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)")
nw = args.num_workers nw = args.num_workers
if nw < 0: if nw < 0:
@@ -1336,13 +1385,18 @@ def main() -> int:
if args.grad_clip > 0: if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step() optimizer.step()
if ema_state is not None:
ema_update_(ema_state, model, ema_decay)
avg_mse = sum_mse / accum avg_mse = sum_mse / accum
avg_mse_val = float(avg_mse.detach().cpu()) avg_mse_val = float(avg_mse.detach().cpu())
if avg_mse_val < best_train_mse: if avg_mse_val < best_train_mse:
best_train_mse = avg_mse_val best_train_mse = avg_mse_val
core = model._orig_mod if hasattr(model, "_orig_mod") else model core = model_core(model)
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()} if ema_state is not None:
best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
else:
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
should_check_early_stop = ( should_check_early_stop = (
args.early_stop_patience > 0 args.early_stop_patience > 0
@@ -1384,7 +1438,7 @@ def main() -> int:
ckpt_path = "" ckpt_path = ""
if best_state is not None: if best_state is not None:
core = model._orig_mod if hasattr(model, "_orig_mod") else model core = model_core(model)
core.load_state_dict(best_state) core.load_state_dict(best_state)
print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})") print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
artifacts_dir = os.path.join(here, "artifacts") artifacts_dir = os.path.join(here, "artifacts")
@@ -1402,6 +1456,7 @@ def main() -> int:
"sample_pickle_pos": int(args.pickle_pos), "sample_pickle_pos": int(args.pickle_pos),
"eval_runs": int(args.eval_runs), "eval_runs": int(args.eval_runs),
"rollout_steps": int(args.rollout_steps), "rollout_steps": int(args.rollout_steps),
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
} }
torch.save(ckpt, ckpt_path) torch.save(ckpt, ckpt_path)
print(f"saved best checkpoint: {ckpt_path}") print(f"saved best checkpoint: {ckpt_path}")
@@ -1435,6 +1490,7 @@ def main() -> int:
"checkpoint_path": ckpt_path, "checkpoint_path": ckpt_path,
"stop_reason": stop_reason, "stop_reason": stop_reason,
"stop_epoch": int(stop_epoch), "stop_epoch": int(stop_epoch),
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
} }
with open(latest_path, "w", encoding="utf-8") as f: with open(latest_path, "w", encoding="utf-8") as f:
json.dump(latest_report, f, indent=2) json.dump(latest_report, f, indent=2)