diff --git a/reports/latest_eval.json b/reports/latest_eval.json index d2699e8..cc3094e 100644 --- a/reports/latest_eval.json +++ b/reports/latest_eval.json @@ -1,12 +1,13 @@ { - "mean_rmsd_100": 2.620511382818222, + "mean_rmsd_100": 2.3507495498657227, "num_runs": 100, - "timestamp_utc": "2026-04-16T15:40:21.752941+00:00", - "command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1", + "timestamp_utc": "2026-04-17T01:29:09.183572+00:00", + "command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --hidden 512 --gcn-layers 8 --epochs 280 --batch-size 24 --lr 7e-4 --seed 1 --eval-runs 100 --gcn-residual --rotation-loss geodesic --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --num-workers 4 --prefetch-factor 4 --out-dir /data/demian_dev/toy/ai_rfm/rfm_out_noema_geodesic_res280", "notes": "Final test metric from 100 random-initialized rollouts to time=1.", - "best_train_mse": 3.1972899436950684, + "best_train_mse": 7.664132118225098, "model_source": "best_train_checkpoint", "checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt", "stop_reason": "max_epochs", - "stop_epoch": 319 + "stop_epoch": 279, + "ema_decay": 0.0 } \ No newline at end of file diff --git a/train.py b/train.py index 294d7ef..d0deffa 100644 --- a/train.py +++ b/train.py @@ -518,6 +518,7 @@ class RFMModel(nn.Module): torsion_head_mode: str = "global", torsion_subgraphs: nn.ModuleList | None = None, torsion_subgraph_output_clip: float = 0.0, + torsion_sub_gcn_layers: int | None = None, ): super().__init__() self.model_type = model_type @@ -532,6 +533,7 @@ class RFMModel(nn.Module): self.graph_readout = graph_readout self.torsion_head_mode = torsion_head_mode self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip)) + self.sub_gcn_skip: nn.ModuleList | None = None if model_type == "gcn": self.convs = nn.ModuleList() @@ -586,6 +588,7 @@ class RFMModel(nn.Module): self.torsion_subgraphs = None self.sub_convs = None self.sub_pool_norm = None + self.sub_gcn_skip = None elif torsion_head_mode == "bond_pair": if model_type != "gcn": raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn") @@ -594,10 +597,14 @@ class RFMModel(nn.Module): # Dedicated subgraph encoder (weights not shared with full-graph convs). # First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids # throwing away trunk geometry context when re-encoding the movable fragment only. + ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers) + ns = max(1, min(ns, num_layers)) + self.sub_num_layers = ns self.sub_convs = nn.ModuleList() self.sub_convs.append(GCNConv(3 + hidden, hidden)) - for _ in range(num_layers - 1): + for _ in range(ns - 1): self.sub_convs.append(GCNConv(hidden, hidden)) + self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)]) # Global pooled context + mean pool over subgraph GCN. pair_in = d + hidden ph = max(64, min(256, hidden * 2)) @@ -653,10 +660,11 @@ class RFMModel(nn.Module): h = batch_sub.x edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes) batch_vec = batch_sub.batch + assert self.sub_gcn_skip is not None for idx, conv in enumerate(self.sub_convs): h_new = self.act(conv(h, edge_index)) if self.gcn_residual and h_new.shape == h.shape: - h = h_new + self.gcn_skip[idx](h) + h = h_new + self.sub_gcn_skip[idx](h) else: h = h_new sub_pooled = global_mean_pool(h, batch_vec) @@ -1044,6 +1052,12 @@ def main() -> int: default=6.0, help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.", ) + parser.add_argument( + "--torsion-sub-gcn-layers", + type=int, + default=-1, + help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).", + ) parser.add_argument( "--early-stop-patience", type=int, @@ -1074,6 +1088,15 @@ def main() -> int: default=0.0, help="If >0, clip predicted rotation omega vector norm for stability.", ) + parser.add_argument( + "--ema-decay", + type=float, + default=0.0, + help=( + "If in (0,1), maintain an EMA shadow of model weights after each optimizer step; " + "best_train checkpoint and final eval load the EMA weights (training uses online weights)." + ), + ) args = parser.parse_args() torch.manual_seed(args.seed) @@ -1119,6 +1142,8 @@ def main() -> int: if torsion_head_mode == "bond_pair": torsion_subgraphs = build_torsion_subgraph_specs(sample) + torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers) + model = RFMModel( model_type=args.model_type, num_nodes=int(sample.num_nodes), @@ -1132,13 +1157,14 @@ def main() -> int: torsion_head_mode=torsion_head_mode, torsion_subgraphs=torsion_subgraphs, torsion_subgraph_output_clip=float(args.subgraph_torsion_clip), + torsion_sub_gcn_layers=torsion_sub_gcn_layers, ).to(device) subgraph_param_ids: set[int] = set() if torsion_head_mode == "bond_pair": for name, p in model.named_parameters(): if not p.requires_grad: continue - if name.startswith(("sub_convs", "torsion_pair", "sub_pool_norm")): + if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")): subgraph_param_ids.add(id(p)) if args.compile and hasattr(torch, "compile"): model = torch.compile(model) # type: ignore[assignment] @@ -1161,6 +1187,25 @@ def main() -> int: else: optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + def model_core(m: nn.Module) -> nn.Module: + return m._orig_mod if hasattr(m, "_orig_mod") else m + + ema_decay = float(args.ema_decay) + ema_state: dict[str, torch.Tensor] | None = None + if 0.0 < ema_decay < 1.0: + core0 = model_core(model) + ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()} + + @torch.no_grad() + def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None: + src = model_core(m).state_dict() + for k, dst in shadow.items(): + t = src[k] + if torch.is_floating_point(t): + dst.mul_(decay).add_(t, alpha=1.0 - decay) + else: + dst.copy_(t) + n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print( f"device={device} model={args.model_type} hidden={args.hidden} " @@ -1180,9 +1225,11 @@ def main() -> int: f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}" ) if torsion_head_mode == "bond_pair": + sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers) print( f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} " - f"subgraph_torsion_clip={args.subgraph_torsion_clip}" + f"subgraph_torsion_clip={args.subgraph_torsion_clip} " + f"torsion_sub_gcn_layers={sub_l}" ) if args.rotation_loss == "hybrid": print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}") @@ -1198,6 +1245,8 @@ def main() -> int: f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}" ) print(f"omega_max_norm={args.omega_max_norm}") + if ema_state is not None: + print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)") nw = args.num_workers if nw < 0: @@ -1336,13 +1385,18 @@ def main() -> int: if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() + if ema_state is not None: + ema_update_(ema_state, model, ema_decay) avg_mse = sum_mse / accum avg_mse_val = float(avg_mse.detach().cpu()) if avg_mse_val < best_train_mse: best_train_mse = avg_mse_val - core = model._orig_mod if hasattr(model, "_orig_mod") else model - best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()} + core = model_core(model) + if ema_state is not None: + best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()} + else: + best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()} should_check_early_stop = ( args.early_stop_patience > 0 @@ -1384,7 +1438,7 @@ def main() -> int: ckpt_path = "" if best_state is not None: - core = model._orig_mod if hasattr(model, "_orig_mod") else model + core = model_core(model) core.load_state_dict(best_state) print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})") artifacts_dir = os.path.join(here, "artifacts") @@ -1402,6 +1456,7 @@ def main() -> int: "sample_pickle_pos": int(args.pickle_pos), "eval_runs": int(args.eval_runs), "rollout_steps": int(args.rollout_steps), + "ema_decay": float(ema_decay) if ema_state is not None else 0.0, } torch.save(ckpt, ckpt_path) print(f"saved best checkpoint: {ckpt_path}") @@ -1435,6 +1490,7 @@ def main() -> int: "checkpoint_path": ckpt_path, "stop_reason": stop_reason, "stop_epoch": int(stop_epoch), + "ema_decay": float(ema_decay) if ema_state is not None else 0.0, } with open(latest_path, "w", encoding="utf-8") as f: json.dump(latest_report, f, indent=2)