train: EMA option, bond_pair subgraph head, eval metadata; refresh best checkpoint.
Measured mean_rmsd_100=2.350750 (100 eval runs, EMA disabled) on sample.sdf with the residual geodesic GCN config; this improves on the main BEST_PRACTICE anchor of 2.461592.
Made-with: Cursor
This commit is contained in:
@@ -1,12 +1,13 @@
|
|||||||
{
|
{
|
||||||
"mean_rmsd_100": 2.620511382818222,
|
"mean_rmsd_100": 2.3507495498657227,
|
||||||
"num_runs": 100,
|
"num_runs": 100,
|
||||||
"timestamp_utc": "2026-04-16T15:40:21.752941+00:00",
|
"timestamp_utc": "2026-04-17T01:29:09.183572+00:00",
|
||||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1",
|
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --hidden 512 --gcn-layers 8 --epochs 280 --batch-size 24 --lr 7e-4 --seed 1 --eval-runs 100 --gcn-residual --rotation-loss geodesic --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --num-workers 4 --prefetch-factor 4 --out-dir /data/demian_dev/toy/ai_rfm/rfm_out_noema_geodesic_res280",
|
||||||
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
||||||
"best_train_mse": 3.1972899436950684,
|
"best_train_mse": 7.664132118225098,
|
||||||
"model_source": "best_train_checkpoint",
|
"model_source": "best_train_checkpoint",
|
||||||
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
|
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
|
||||||
"stop_reason": "max_epochs",
|
"stop_reason": "max_epochs",
|
||||||
"stop_epoch": 319
|
"stop_epoch": 279,
|
||||||
|
"ema_decay": 0.0
|
||||||
}
|
}
|
||||||
70
train.py
70
train.py
@@ -518,6 +518,7 @@ class RFMModel(nn.Module):
|
|||||||
torsion_head_mode: str = "global",
|
torsion_head_mode: str = "global",
|
||||||
torsion_subgraphs: nn.ModuleList | None = None,
|
torsion_subgraphs: nn.ModuleList | None = None,
|
||||||
torsion_subgraph_output_clip: float = 0.0,
|
torsion_subgraph_output_clip: float = 0.0,
|
||||||
|
torsion_sub_gcn_layers: int | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
@@ -532,6 +533,7 @@ class RFMModel(nn.Module):
|
|||||||
self.graph_readout = graph_readout
|
self.graph_readout = graph_readout
|
||||||
self.torsion_head_mode = torsion_head_mode
|
self.torsion_head_mode = torsion_head_mode
|
||||||
self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
|
self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
|
||||||
|
self.sub_gcn_skip: nn.ModuleList | None = None
|
||||||
|
|
||||||
if model_type == "gcn":
|
if model_type == "gcn":
|
||||||
self.convs = nn.ModuleList()
|
self.convs = nn.ModuleList()
|
||||||
@@ -586,6 +588,7 @@ class RFMModel(nn.Module):
|
|||||||
self.torsion_subgraphs = None
|
self.torsion_subgraphs = None
|
||||||
self.sub_convs = None
|
self.sub_convs = None
|
||||||
self.sub_pool_norm = None
|
self.sub_pool_norm = None
|
||||||
|
self.sub_gcn_skip = None
|
||||||
elif torsion_head_mode == "bond_pair":
|
elif torsion_head_mode == "bond_pair":
|
||||||
if model_type != "gcn":
|
if model_type != "gcn":
|
||||||
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
|
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
|
||||||
@@ -594,10 +597,14 @@ class RFMModel(nn.Module):
|
|||||||
# Dedicated subgraph encoder (weights not shared with full-graph convs).
|
# Dedicated subgraph encoder (weights not shared with full-graph convs).
|
||||||
# First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
|
# First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
|
||||||
# throwing away trunk geometry context when re-encoding the movable fragment only.
|
# throwing away trunk geometry context when re-encoding the movable fragment only.
|
||||||
|
ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers)
|
||||||
|
ns = max(1, min(ns, num_layers))
|
||||||
|
self.sub_num_layers = ns
|
||||||
self.sub_convs = nn.ModuleList()
|
self.sub_convs = nn.ModuleList()
|
||||||
self.sub_convs.append(GCNConv(3 + hidden, hidden))
|
self.sub_convs.append(GCNConv(3 + hidden, hidden))
|
||||||
for _ in range(num_layers - 1):
|
for _ in range(ns - 1):
|
||||||
self.sub_convs.append(GCNConv(hidden, hidden))
|
self.sub_convs.append(GCNConv(hidden, hidden))
|
||||||
|
self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)])
|
||||||
# Global pooled context + mean pool over subgraph GCN.
|
# Global pooled context + mean pool over subgraph GCN.
|
||||||
pair_in = d + hidden
|
pair_in = d + hidden
|
||||||
ph = max(64, min(256, hidden * 2))
|
ph = max(64, min(256, hidden * 2))
|
||||||
@@ -653,10 +660,11 @@ class RFMModel(nn.Module):
|
|||||||
h = batch_sub.x
|
h = batch_sub.x
|
||||||
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
|
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
|
||||||
batch_vec = batch_sub.batch
|
batch_vec = batch_sub.batch
|
||||||
|
assert self.sub_gcn_skip is not None
|
||||||
for idx, conv in enumerate(self.sub_convs):
|
for idx, conv in enumerate(self.sub_convs):
|
||||||
h_new = self.act(conv(h, edge_index))
|
h_new = self.act(conv(h, edge_index))
|
||||||
if self.gcn_residual and h_new.shape == h.shape:
|
if self.gcn_residual and h_new.shape == h.shape:
|
||||||
h = h_new + self.gcn_skip[idx](h)
|
h = h_new + self.sub_gcn_skip[idx](h)
|
||||||
else:
|
else:
|
||||||
h = h_new
|
h = h_new
|
||||||
sub_pooled = global_mean_pool(h, batch_vec)
|
sub_pooled = global_mean_pool(h, batch_vec)
|
||||||
@@ -1044,6 +1052,12 @@ def main() -> int:
|
|||||||
default=6.0,
|
default=6.0,
|
||||||
help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
|
help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--torsion-sub-gcn-layers",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early-stop-patience",
|
"--early-stop-patience",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -1074,6 +1088,15 @@ def main() -> int:
|
|||||||
default=0.0,
|
default=0.0,
|
||||||
help="If >0, clip predicted rotation omega vector norm for stability.",
|
help="If >0, clip predicted rotation omega vector norm for stability.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ema-decay",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help=(
|
||||||
|
"If in (0,1), maintain an EMA shadow of model weights after each optimizer step; "
|
||||||
|
"best_train checkpoint and final eval load the EMA weights (training uses online weights)."
|
||||||
|
),
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
@@ -1119,6 +1142,8 @@ def main() -> int:
|
|||||||
if torsion_head_mode == "bond_pair":
|
if torsion_head_mode == "bond_pair":
|
||||||
torsion_subgraphs = build_torsion_subgraph_specs(sample)
|
torsion_subgraphs = build_torsion_subgraph_specs(sample)
|
||||||
|
|
||||||
|
torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
|
||||||
|
|
||||||
model = RFMModel(
|
model = RFMModel(
|
||||||
model_type=args.model_type,
|
model_type=args.model_type,
|
||||||
num_nodes=int(sample.num_nodes),
|
num_nodes=int(sample.num_nodes),
|
||||||
@@ -1132,13 +1157,14 @@ def main() -> int:
|
|||||||
torsion_head_mode=torsion_head_mode,
|
torsion_head_mode=torsion_head_mode,
|
||||||
torsion_subgraphs=torsion_subgraphs,
|
torsion_subgraphs=torsion_subgraphs,
|
||||||
torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
|
torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
|
||||||
|
torsion_sub_gcn_layers=torsion_sub_gcn_layers,
|
||||||
).to(device)
|
).to(device)
|
||||||
subgraph_param_ids: set[int] = set()
|
subgraph_param_ids: set[int] = set()
|
||||||
if torsion_head_mode == "bond_pair":
|
if torsion_head_mode == "bond_pair":
|
||||||
for name, p in model.named_parameters():
|
for name, p in model.named_parameters():
|
||||||
if not p.requires_grad:
|
if not p.requires_grad:
|
||||||
continue
|
continue
|
||||||
if name.startswith(("sub_convs", "torsion_pair", "sub_pool_norm")):
|
if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")):
|
||||||
subgraph_param_ids.add(id(p))
|
subgraph_param_ids.add(id(p))
|
||||||
if args.compile and hasattr(torch, "compile"):
|
if args.compile and hasattr(torch, "compile"):
|
||||||
model = torch.compile(model) # type: ignore[assignment]
|
model = torch.compile(model) # type: ignore[assignment]
|
||||||
@@ -1161,6 +1187,25 @@ def main() -> int:
|
|||||||
else:
|
else:
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
|
def model_core(m: nn.Module) -> nn.Module:
|
||||||
|
return m._orig_mod if hasattr(m, "_orig_mod") else m
|
||||||
|
|
||||||
|
ema_decay = float(args.ema_decay)
|
||||||
|
ema_state: dict[str, torch.Tensor] | None = None
|
||||||
|
if 0.0 < ema_decay < 1.0:
|
||||||
|
core0 = model_core(model)
|
||||||
|
ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None:
|
||||||
|
src = model_core(m).state_dict()
|
||||||
|
for k, dst in shadow.items():
|
||||||
|
t = src[k]
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
dst.mul_(decay).add_(t, alpha=1.0 - decay)
|
||||||
|
else:
|
||||||
|
dst.copy_(t)
|
||||||
|
|
||||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
print(
|
print(
|
||||||
f"device={device} model={args.model_type} hidden={args.hidden} "
|
f"device={device} model={args.model_type} hidden={args.hidden} "
|
||||||
@@ -1180,9 +1225,11 @@ def main() -> int:
|
|||||||
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
|
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
|
||||||
)
|
)
|
||||||
if torsion_head_mode == "bond_pair":
|
if torsion_head_mode == "bond_pair":
|
||||||
|
sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
|
||||||
print(
|
print(
|
||||||
f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
|
f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
|
||||||
f"subgraph_torsion_clip={args.subgraph_torsion_clip}"
|
f"subgraph_torsion_clip={args.subgraph_torsion_clip} "
|
||||||
|
f"torsion_sub_gcn_layers={sub_l}"
|
||||||
)
|
)
|
||||||
if args.rotation_loss == "hybrid":
|
if args.rotation_loss == "hybrid":
|
||||||
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
|
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
|
||||||
@@ -1198,6 +1245,8 @@ def main() -> int:
|
|||||||
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
|
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
|
||||||
)
|
)
|
||||||
print(f"omega_max_norm={args.omega_max_norm}")
|
print(f"omega_max_norm={args.omega_max_norm}")
|
||||||
|
if ema_state is not None:
|
||||||
|
print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)")
|
||||||
|
|
||||||
nw = args.num_workers
|
nw = args.num_workers
|
||||||
if nw < 0:
|
if nw < 0:
|
||||||
@@ -1336,13 +1385,18 @@ def main() -> int:
|
|||||||
if args.grad_clip > 0:
|
if args.grad_clip > 0:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
if ema_state is not None:
|
||||||
|
ema_update_(ema_state, model, ema_decay)
|
||||||
|
|
||||||
avg_mse = sum_mse / accum
|
avg_mse = sum_mse / accum
|
||||||
avg_mse_val = float(avg_mse.detach().cpu())
|
avg_mse_val = float(avg_mse.detach().cpu())
|
||||||
if avg_mse_val < best_train_mse:
|
if avg_mse_val < best_train_mse:
|
||||||
best_train_mse = avg_mse_val
|
best_train_mse = avg_mse_val
|
||||||
core = model._orig_mod if hasattr(model, "_orig_mod") else model
|
core = model_core(model)
|
||||||
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
|
if ema_state is not None:
|
||||||
|
best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
|
||||||
|
else:
|
||||||
|
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
|
||||||
|
|
||||||
should_check_early_stop = (
|
should_check_early_stop = (
|
||||||
args.early_stop_patience > 0
|
args.early_stop_patience > 0
|
||||||
@@ -1384,7 +1438,7 @@ def main() -> int:
|
|||||||
|
|
||||||
ckpt_path = ""
|
ckpt_path = ""
|
||||||
if best_state is not None:
|
if best_state is not None:
|
||||||
core = model._orig_mod if hasattr(model, "_orig_mod") else model
|
core = model_core(model)
|
||||||
core.load_state_dict(best_state)
|
core.load_state_dict(best_state)
|
||||||
print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
|
print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
|
||||||
artifacts_dir = os.path.join(here, "artifacts")
|
artifacts_dir = os.path.join(here, "artifacts")
|
||||||
@@ -1402,6 +1456,7 @@ def main() -> int:
|
|||||||
"sample_pickle_pos": int(args.pickle_pos),
|
"sample_pickle_pos": int(args.pickle_pos),
|
||||||
"eval_runs": int(args.eval_runs),
|
"eval_runs": int(args.eval_runs),
|
||||||
"rollout_steps": int(args.rollout_steps),
|
"rollout_steps": int(args.rollout_steps),
|
||||||
|
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
||||||
}
|
}
|
||||||
torch.save(ckpt, ckpt_path)
|
torch.save(ckpt, ckpt_path)
|
||||||
print(f"saved best checkpoint: {ckpt_path}")
|
print(f"saved best checkpoint: {ckpt_path}")
|
||||||
@@ -1435,6 +1490,7 @@ def main() -> int:
|
|||||||
"checkpoint_path": ckpt_path,
|
"checkpoint_path": ckpt_path,
|
||||||
"stop_reason": stop_reason,
|
"stop_reason": stop_reason,
|
||||||
"stop_epoch": int(stop_epoch),
|
"stop_epoch": int(stop_epoch),
|
||||||
|
"ema_decay": float(ema_decay) if ema_state is not None else 0.0,
|
||||||
}
|
}
|
||||||
with open(latest_path, "w", encoding="utf-8") as f:
|
with open(latest_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(latest_report, f, indent=2)
|
json.dump(latest_report, f, indent=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user