train: EMA option, bond_pair subgraph head, eval metadata; refresh best checkpoint.
Measured mean_rmsd_100=2.350750 (100 runs) on sample.sdf with residual geodesic config; improves main BEST_PRACTICE anchor 2.461592. Made-with: Cursor
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
{
|
||||
"mean_rmsd_100": 2.620511382818222,
|
||||
"mean_rmsd_100": 2.3507495498657227,
|
||||
"num_runs": 100,
|
||||
"timestamp_utc": "2026-04-16T15:40:21.752941+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1",
|
||||
"timestamp_utc": "2026-04-17T01:29:09.183572+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --hidden 512 --gcn-layers 8 --epochs 280 --batch-size 24 --lr 7e-4 --seed 1 --eval-runs 100 --gcn-residual --rotation-loss geodesic --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --num-workers 4 --prefetch-factor 4 --out-dir /data/demian_dev/toy/ai_rfm/rfm_out_noema_geodesic_res280",
|
||||
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
||||
"best_train_mse": 3.1972899436950684,
|
||||
"best_train_mse": 7.664132118225098,
|
||||
"model_source": "best_train_checkpoint",
|
||||
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
|
||||
"stop_reason": "max_epochs",
|
||||
"stop_epoch": 319
|
||||
"stop_epoch": 279,
|
||||
"ema_decay": 0.0
|
||||
}
|
||||
66
train.py
66
train.py
@@ -518,6 +518,7 @@ class RFMModel(nn.Module):
|
||||
torsion_head_mode: str = "global",
|
||||
torsion_subgraphs: nn.ModuleList | None = None,
|
||||
torsion_subgraph_output_clip: float = 0.0,
|
||||
torsion_sub_gcn_layers: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_type = model_type
|
||||
@@ -532,6 +533,7 @@ class RFMModel(nn.Module):
|
||||
self.graph_readout = graph_readout
|
||||
self.torsion_head_mode = torsion_head_mode
|
||||
self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
|
||||
self.sub_gcn_skip: nn.ModuleList | None = None
|
||||
|
||||
if model_type == "gcn":
|
||||
self.convs = nn.ModuleList()
|
||||
@@ -586,6 +588,7 @@ class RFMModel(nn.Module):
|
||||
self.torsion_subgraphs = None
|
||||
self.sub_convs = None
|
||||
self.sub_pool_norm = None
|
||||
self.sub_gcn_skip = None
|
||||
elif torsion_head_mode == "bond_pair":
|
||||
if model_type != "gcn":
|
||||
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
|
||||
@@ -594,10 +597,14 @@ class RFMModel(nn.Module):
|
||||
# Dedicated subgraph encoder (weights not shared with full-graph convs).
|
||||
# First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
|
||||
# throwing away trunk geometry context when re-encoding the movable fragment only.
|
||||
ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers)
|
||||
ns = max(1, min(ns, num_layers))
|
||||
self.sub_num_layers = ns
|
||||
self.sub_convs = nn.ModuleList()
|
||||
self.sub_convs.append(GCNConv(3 + hidden, hidden))
|
||||
for _ in range(num_layers - 1):
|
||||
for _ in range(ns - 1):
|
||||
self.sub_convs.append(GCNConv(hidden, hidden))
|
||||
self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)])
|
||||
# Global pooled context + mean pool over subgraph GCN.
|
||||
pair_in = d + hidden
|
||||
ph = max(64, min(256, hidden * 2))
|
||||
@@ -653,10 +660,11 @@ class RFMModel(nn.Module):
|
||||
h = batch_sub.x
|
||||
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
|
||||
batch_vec = batch_sub.batch
|
||||
assert self.sub_gcn_skip is not None
|
||||
for idx, conv in enumerate(self.sub_convs):
|
||||
h_new = self.act(conv(h, edge_index))
|
||||
if self.gcn_residual and h_new.shape == h.shape:
|
||||
h = h_new + self.gcn_skip[idx](h)
|
||||
h = h_new + self.sub_gcn_skip[idx](h)
|
||||
else:
|
||||
h = h_new
|
||||
sub_pooled = global_mean_pool(h, batch_vec)
|
||||
@@ -1044,6 +1052,12 @@ def main() -> int:
|
||||
default=6.0,
|
||||
help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torsion-sub-gcn-layers",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early-stop-patience",
|
||||
type=int,
|
||||
@@ -1074,6 +1088,15 @@ def main() -> int:
|
||||
default=0.0,
|
||||
help="If >0, clip predicted rotation omega vector norm for stability.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ema-decay",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help=(
|
||||
"If in (0,1), maintain an EMA shadow of model weights after each optimizer step; "
|
||||
"best_train checkpoint and final eval load the EMA weights (training uses online weights)."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@@ -1119,6 +1142,8 @@ def main() -> int:
|
||||
if torsion_head_mode == "bond_pair":
|
||||
torsion_subgraphs = build_torsion_subgraph_specs(sample)
|
||||
|
||||
torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
|
||||
|
||||
model = RFMModel(
|
||||
model_type=args.model_type,
|
||||
num_nodes=int(sample.num_nodes),
|
||||
@@ -1132,13 +1157,14 @@ def main() -> int:
|
||||
torsion_head_mode=torsion_head_mode,
|
||||
torsion_subgraphs=torsion_subgraphs,
|
||||
torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
|
||||
torsion_sub_gcn_layers=torsion_sub_gcn_layers,
|
||||
).to(device)
|
||||
subgraph_param_ids: set[int] = set()
|
||||
if torsion_head_mode == "bond_pair":
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if name.startswith(("sub_convs", "torsion_pair", "sub_pool_norm")):
|
||||
if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")):
|
||||
subgraph_param_ids.add(id(p))
|
||||
if args.compile and hasattr(torch, "compile"):
|
||||
model = torch.compile(model) # type: ignore[assignment]
|
||||
@@ -1161,6 +1187,25 @@ def main() -> int:
|
||||
else:
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
def model_core(m: nn.Module) -> nn.Module:
|
||||
return m._orig_mod if hasattr(m, "_orig_mod") else m
|
||||
|
||||
ema_decay = float(args.ema_decay)
|
||||
ema_state: dict[str, torch.Tensor] | None = None
|
||||
if 0.0 < ema_decay < 1.0:
|
||||
core0 = model_core(model)
|
||||
ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()}
|
||||
|
||||
@torch.no_grad()
|
||||
def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None:
|
||||
src = model_core(m).state_dict()
|
||||
for k, dst in shadow.items():
|
||||
t = src[k]
|
||||
if torch.is_floating_point(t):
|
||||
dst.mul_(decay).add_(t, alpha=1.0 - decay)
|
||||
else:
|
||||
dst.copy_(t)
|
||||
|
||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(
|
||||
f"device={device} model={args.model_type} hidden={args.hidden} "
|
||||
@@ -1180,9 +1225,11 @@ def main() -> int:
|
||||
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
|
||||
)
|
||||
if torsion_head_mode == "bond_pair":
|
||||
sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
|
||||
print(
|
||||
f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
|
||||
f"subgraph_torsion_clip={args.subgraph_torsion_clip} "
|
||||
f"torsion_sub_gcn_layers={sub_l}"
|
||||
)
|
||||
if args.rotation_loss == "hybrid":
|
||||
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
|
||||
@@ -1198,6 +1245,8 @@ def main() -> int:
|
||||
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
|
||||
)
|
||||
print(f"omega_max_norm={args.omega_max_norm}")
|
||||
if ema_state is not None:
|
||||
print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)")
|
||||
|
||||
nw = args.num_workers
|
||||
if nw < 0:
|
||||
@@ -1336,12 +1385,17 @@ def main() -> int:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
if ema_state is not None:
|
||||
ema_update_(ema_state, model, ema_decay)
|
||||
|
||||
avg_mse = sum_mse / accum
|
||||
avg_mse_val = float(avg_mse.detach().cpu())
|
||||
if avg_mse_val < best_train_mse:
|
||||
best_train_mse = avg_mse_val
|
||||
core = model._orig_mod if hasattr(model, "_orig_mod") else model
|
||||
core = model_core(model)
|
||||
if ema_state is not None:
|
||||
best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
|
||||
else:
|
||||
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
|
||||
|
||||
should_check_early_stop = (
|
||||
@@ -1384,7 +1438,7 @@ def main() -> int:
|
||||
|
||||
ckpt_path = ""
|
||||
if best_state is not None:
|
||||
core = model._orig_mod if hasattr(model, "_orig_mod") else model
|
||||
core = model_core(model)
|
||||
core.load_state_dict(best_state)
|
||||
print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
|
||||
artifacts_dir = os.path.join(here, "artifacts")
|
||||
@@ -1402,6 +1456,7 @@ def main() -> int:
|
||||
"sample_pickle_pos": int(args.pickle_pos),
|
||||
"eval_runs": int(args.eval_runs),
|
||||
"rollout_steps": int(args.rollout_steps),
|
||||
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
||||
}
|
||||
torch.save(ckpt, ckpt_path)
|
||||
print(f"saved best checkpoint: {ckpt_path}")
|
||||
@@ -1435,6 +1490,7 @@ def main() -> int:
|
||||
"checkpoint_path": ckpt_path,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_epoch": int(stop_epoch),
|
||||
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
||||
}
|
||||
with open(latest_path, "w", encoding="utf-8") as f:
|
||||
json.dump(latest_report, f, indent=2)
|
||||
|
||||
Reference in New Issue
Block a user