Merge branch 'attempt/s3-restart-after-doc-sync' into post-main-doc-sync

This commit is contained in:
demian3b
2026-04-16 23:59:41 +09:00
2 changed files with 298 additions and 22 deletions

View File

@@ -1,10 +1,10 @@
{
"mean_rmsd_100": 2.464729991555214,
"mean_rmsd_100": 2.598530296087265,
"num_runs": 100,
"timestamp_utc": "2026-04-16T14:27:02.733292+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --seed 1",
"timestamp_utc": "2026-04-16T14:56:28.000610+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --seed 1",
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
"best_train_mse": 3.5417799949645996,
"best_train_mse": 6.003782272338867,
"model_source": "best_train_checkpoint",
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
"stop_reason": "max_epochs",

312
train.py
View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import argparse
import json
import math
import os
import pickle
import sys
@@ -26,7 +27,8 @@ from rdkit.Chem import SDWriter
from rdkit.Geometry.rdGeometry import Point3D
from torch import nn
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool
from torch_geometric.utils import softmax
# ---------------------------------------------------------------------------
@@ -207,6 +209,19 @@ def torsion_wrapped_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.
return torch.mean(wrap_angle(pred - target) ** 2)
def upper_quantile_mean(values: torch.Tensor, quantile: float) -> torch.Tensor:
"""Mean of upper-quantile tail values."""
q = float(max(0.0, min(1.0, quantile)))
if q <= 0.0:
return values.mean()
if q >= 1.0:
return values.max()
n = int(values.numel())
k = max(1, int(math.ceil(n * (1.0 - q))))
top_vals, _ = torch.topk(values.reshape(-1), k=k, largest=True)
return top_vals.mean()
def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor) -> torch.Tensor:
"""Geodesic MSE between batched axis-angle increments."""
if pred_omega.ndim != 2 or pred_omega.shape[-1] != 3:
@@ -223,6 +238,14 @@ def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor)
return torch.stack(losses).mean()
def clip_vector_norm_batch(x: torch.Tensor, max_norm: float) -> torch.Tensor:
if max_norm <= 0:
return x
norms = torch.linalg.norm(x, dim=-1, keepdim=True).clamp(min=1e-9)
scale = torch.clamp(max_norm / norms, max=1.0)
return x * scale
def skew(w: torch.Tensor) -> torch.Tensor:
z = torch.zeros((), device=w.device, dtype=w.dtype)
return torch.stack(
@@ -365,6 +388,7 @@ def evaluate_mean_rmsd_100(
device: torch.device,
num_runs: int = 100,
rollout_steps: int = 100,
omega_max_norm: float = 0.0,
) -> float:
"""Run 100 one-shot predictions to time=1 and average final-structure RMSD."""
del rollout_steps
@@ -390,7 +414,7 @@ def evaluate_mean_rmsd_100(
with torch.no_grad():
v = model(batch, tb)
vc = v.center.detach().cpu()
vw = v.rot_omega.detach().cpu()
vw = clip_vector_norm_batch(v.rot_omega.detach().cpu(), omega_max_norm)
vt = v.torsion.detach().cpu()
center_t = center_t + vc
torsion_t = torsion_t + vt
@@ -407,6 +431,68 @@ def evaluate_mean_rmsd_100(
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class TorsionSubgraphSpec(nn.Module):
"""Movable-side induced subgraph for one torsion (full-graph node ids + local edge_index)."""
def __init__(self, nodes_full: torch.Tensor, edge_index_local: torch.Tensor):
super().__init__()
self.register_buffer("nodes_full", nodes_full, persistent=True)
self.register_buffer("edge_index_local", edge_index_local, persistent=True)
def build_torsion_subgraph_specs(sample: Data) -> nn.ModuleList:
"""Induced subgraph on atoms where movable_mask[k] is True for each torsion k."""
n = int(sample.num_nodes)
ei = sample.edge_index
if ei.device.type != "cpu":
ei = ei.cpu()
mv = sample.movable_mask
if mv.device.type != "cpu":
mv = mv.cpu()
t_count = int(sample.bond_endpoints.shape[0])
modules: list[TorsionSubgraphSpec] = []
for k in range(t_count):
mask_row = mv[k].bool()
nodes_full = torch.where(mask_row)[0].long()
if nodes_full.numel() == 0:
raise RuntimeError(f"torsion {k}: empty movable side; cannot build subgraph")
node_set = set(int(x) for x in nodes_full.tolist())
old_to_new = {old: i for i, old in enumerate(sorted(node_set))}
es: list[int] = []
ed: list[int] = []
for e in range(ei.shape[1]):
a, b = int(ei[0, e]), int(ei[1, e])
if a in node_set and b in node_set:
es.append(old_to_new[a])
ed.append(old_to_new[b])
if len(es) > 0:
edge_index_local = torch.tensor([es, ed], dtype=torch.long)
else:
nk = nodes_full.numel()
edge_index_local = torch.stack(
[torch.arange(nk, dtype=torch.long), torch.arange(nk, dtype=torch.long)], dim=0
)
modules.append(TorsionSubgraphSpec(nodes_full, edge_index_local))
return nn.ModuleList(modules)
class AttentionGraphReadout(nn.Module):
"""Learned softmax weights over nodes (distinct inductive bias vs mean pooling)."""
def __init__(self, node_dim: int, gate_hidden: int = 64):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(node_dim, gate_hidden),
nn.SiLU(),
nn.Linear(gate_hidden, 1),
)
def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
scores = self.gate(h).squeeze(-1)
alpha = softmax(scores, batch)
return global_add_pool(h * alpha.unsqueeze(-1), batch)
class RFMModel(nn.Module):
"""Model for velocity prediction. Supports GCN and fixed-order MLP encoders."""
@@ -420,6 +506,9 @@ class RFMModel(nn.Module):
gcn_residual: bool = False,
channel_layernorm: bool = False,
head_mlp_layers: int = 1,
graph_readout: str = "mean",
torsion_head_mode: str = "global",
torsion_subgraphs: nn.ModuleList | None = None,
):
super().__init__()
self.model_type = model_type
@@ -431,6 +520,8 @@ class RFMModel(nn.Module):
self.channel_layernorm = channel_layernorm
self.act = nn.SiLU()
self.head_mlp_layers = max(1, int(head_mlp_layers))
self.graph_readout = graph_readout
self.torsion_head_mode = torsion_head_mode
if model_type == "gcn":
self.convs = nn.ModuleList()
@@ -441,7 +532,14 @@ class RFMModel(nn.Module):
for _ in range(num_layers - 1):
self.gcn_skip.append(nn.Identity())
trunk_dim = hidden
if graph_readout == "mean":
self._pool = None
elif graph_readout == "attention":
self._pool = AttentionGraphReadout(trunk_dim, gate_hidden=max(32, min(128, hidden)))
else:
raise ValueError(f"Unsupported graph_readout: {graph_readout}")
elif model_type == "mlp":
self._pool = None
mlp_layers = []
in_dim = num_nodes * 3
for _ in range(num_layers):
@@ -471,7 +569,69 @@ class RFMModel(nn.Module):
self.trans_head = make_head(3)
self.rot_head = make_head(3)
self.torsion_head = make_head(num_torsions)
if torsion_head_mode == "global":
self.torsion_head = make_head(num_torsions)
self.torsion_pair_mlp = None
self.torsion_pair_in_norm = None
self.torsion_subgraphs = None
elif torsion_head_mode == "bond_pair":
if model_type != "gcn":
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
if torsion_subgraphs is None or len(torsion_subgraphs) != num_torsions:
raise ValueError("bond_pair requires torsion_subgraphs built from sample.movable_mask")
# Global pooled context + mean pool over a GCN run on the movable-side induced subgraph only.
pair_in = d + hidden
ph = max(64, min(256, hidden * 2))
self.torsion_head = None
self.torsion_subgraphs = torsion_subgraphs
self.torsion_pair_in_norm = nn.LayerNorm(pair_in)
self.torsion_pair_mlp = nn.Sequential(
nn.Linear(pair_in, ph),
nn.SiLU(),
nn.Linear(ph, ph),
nn.SiLU(),
nn.Linear(ph, 1),
)
else:
raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor) -> torch.Tensor:
"""Per-torsion: same GCN weights on movable-side subgraph coords, then pooled + MLP."""
if self.num_torsions == 0:
return pooled.new_zeros((pooled.shape[0], 0))
assert self.torsion_subgraphs is not None
assert self.torsion_pair_mlp is not None and self.torsion_pair_in_norm is not None
if not hasattr(x, "ptr") or x.ptr is None:
raise RuntimeError("bond_pair torsion head requires Batch.ptr (use Batch.from_data_list)")
B = int(pooled.shape[0])
ptr = x.ptr.to(device=pooled.device)
device = pooled.device
dtype = pooled.dtype
outs: list[torch.Tensor] = []
for k in range(self.num_torsions):
spec = self.torsion_subgraphs[k]
nodes_f = spec.nodes_full.to(device=device, dtype=torch.long)
ei_loc = spec.edge_index_local.to(device=device)
graphs: list[Data] = []
for b in range(B):
base = int(ptr[b].item())
coords_sub = x.coords[base + nodes_f].to(dtype=dtype)
graphs.append(Data(x=coords_sub, edge_index=ei_loc))
batch_sub = Batch.from_data_list(graphs).to(device)
h = batch_sub.x
edge_index = batch_sub.edge_index
batch_vec = batch_sub.batch
for idx, conv in enumerate(self.convs):
h_new = self.act(conv(h, edge_index))
if self.gcn_residual and h_new.shape == h.shape:
h = h_new + self.gcn_skip[idx](h)
else:
h = h_new
sub_trunk = global_mean_pool(h, batch_vec)
feat = self.torsion_pair_in_norm(torch.cat([pooled, sub_trunk], dim=-1))
outs.append(self.torsion_pair_mlp(feat).squeeze(-1))
return torch.stack(outs, dim=1)
def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
if self.model_type == "gcn":
@@ -484,7 +644,10 @@ class RFMModel(nn.Module):
h = h_new + self.gcn_skip[idx](h)
else:
h = h_new
trunk = global_mean_pool(h, batch)
if self._pool is None:
trunk = global_mean_pool(h, batch)
else:
trunk = self._pool(h, batch)
else:
bsz = time.shape[0]
coords = x.coords.reshape(bsz, self.num_nodes * 3)
@@ -496,7 +659,12 @@ class RFMModel(nn.Module):
pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
trans = self.trans_head(pooled)
rot_omega = self.rot_head(pooled)
torsion = self.torsion_head(pooled)
if self.torsion_head_mode == "bond_pair":
assert self.model_type == "gcn"
torsion = self._torsion_from_subgraph_specs(x, pooled)
else:
assert self.torsion_head is not None
torsion = self.torsion_head(pooled)
return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
@@ -673,6 +841,18 @@ def main() -> int:
parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
parser.add_argument(
"--tail-risk-weight",
type=float,
default=0.0,
help="Extra weight on upper-quantile per-sample loss (0 disables).",
)
parser.add_argument(
"--tail-risk-quantile",
type=float,
default=0.9,
help="Upper quantile threshold for tail-risk penalty.",
)
parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
@@ -757,9 +937,27 @@ def main() -> int:
"--rotation-loss",
type=str,
default="mse",
choices=["mse", "geodesic"],
choices=["mse", "geodesic", "hybrid"],
help="Rotation-channel loss type.",
)
parser.add_argument(
"--rotation-hybrid-alpha",
type=float,
default=0.5,
help="For rotation-loss=hybrid: weight on geodesic term (0..1).",
)
parser.add_argument(
"--rotation-weight-start",
type=float,
default=-1.0,
help="If >=0, linearly anneal effective rotation weight from start to --weight-omega.",
)
parser.add_argument(
"--rotation-weight-warmup-epochs",
type=int,
default=0,
help="Warmup epochs for rotation weight schedule.",
)
parser.add_argument(
"--gcn-residual",
action="store_true",
@@ -776,6 +974,24 @@ def main() -> int:
default=1,
help="Per-channel head depth (>=1).",
)
parser.add_argument(
"--graph-readout",
type=str,
default="mean",
choices=["mean", "attention"],
help="GCN graph readout: mean pool vs learned attention pool (ignored for model-type=mlp).",
)
parser.add_argument(
"--torsion-head",
type=str,
default="global",
choices=["global", "bond_pair"],
help=(
"global: torsion vector from pooled trunk+time. "
"bond_pair (GCN only): per torsion k, run the same GCN trunk on the movable-side induced subgraph "
"(mask selects nodes/edges only; mask values are not fed as features), mean-pool, then MLP with global pooled context."
),
)
parser.add_argument(
"--early-stop-patience",
type=int,
@@ -800,6 +1016,12 @@ def main() -> int:
default=0,
help="Skip early-stop checks for first N epochs.",
)
parser.add_argument(
"--omega-max-norm",
type=float,
default=0.0,
help="If >0, clip predicted rotation omega vector norm for stability.",
)
args = parser.parse_args()
torch.manual_seed(args.seed)
@@ -819,6 +1041,13 @@ def main() -> int:
return 1
sample = prepare_sample(sample)
torsion_head_mode = args.torsion_head
if torsion_head_mode == "bond_pair" and args.model_type != "gcn":
print(
"Warning: --torsion-head bond_pair requires --model-type gcn; using global torsion head.",
file=sys.stderr,
)
torsion_head_mode = "global"
os.makedirs(args.out_dir, exist_ok=True)
save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))
@@ -834,6 +1063,10 @@ def main() -> int:
torsion=sample.gt_torsion,
)
torsion_subgraphs: nn.ModuleList | None = None
if torsion_head_mode == "bond_pair":
torsion_subgraphs = build_torsion_subgraph_specs(sample)
model = RFMModel(
model_type=args.model_type,
num_nodes=int(sample.num_nodes),
@@ -843,6 +1076,9 @@ def main() -> int:
gcn_residual=args.gcn_residual,
channel_layernorm=args.channel_layernorm,
head_mlp_layers=args.head_mlp_layers,
graph_readout=args.graph_readout,
torsion_head_mode=torsion_head_mode,
torsion_subgraphs=torsion_subgraphs,
).to(device)
if args.compile and hasattr(torch, "compile"):
model = torch.compile(model) # type: ignore[assignment]
@@ -859,16 +1095,27 @@ def main() -> int:
f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
f"grad_clip={args.grad_clip}"
)
print(f"tail_risk: weight={args.tail_risk_weight} quantile={args.tail_risk_quantile}")
print("loss domain=%s torsion_wrapped_loss=always_on" % args.loss_domain)
print(
f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers}"
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers} "
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
)
if args.rotation_loss == "hybrid":
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
if args.rotation_weight_start >= 0:
print(
"rotation_weight_schedule: "
f"start={args.rotation_weight_start} end={args.weight_omega} "
f"warmup_epochs={args.rotation_weight_warmup_epochs}"
)
print(
"early_stop: "
f"patience={args.early_stop_patience} min_delta={args.early_stop_min_delta} "
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
)
print(f"omega_max_norm={args.omega_max_norm}")
nw = args.num_workers
if nw < 0:
@@ -941,14 +1188,14 @@ def main() -> int:
time_remain = (1.0 - time_b).clamp(min=1e-6)
if args.loss_domain == "displacement":
pred_center = pred.center * time_remain
pred_omega = pred.rot_omega * time_remain
pred_omega = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm) * time_remain
pred_torsion = pred.torsion * time_remain
tgt_center = tc * time_remain
tgt_omega = tw * time_remain
tgt_torsion = tt * time_remain
else:
pred_center = pred.center
pred_omega = pred.rot_omega
pred_omega = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm)
pred_torsion = pred.torsion
tgt_center = tc
tgt_omega = tw
@@ -958,21 +1205,49 @@ def main() -> int:
scale_c = (tgt_center * tgt_center).mean().clamp(min=eps)
scale_w = (tgt_omega * tgt_omega).mean().clamp(min=eps)
scale_t = (tgt_torsion * tgt_torsion).mean().clamp(min=eps)
per_center = ((pred_center - tgt_center) ** 2).mean(dim=-1) / scale_c
if args.rotation_loss == "geodesic":
rot_loss = so3_geodesic_mse_loss(pred_omega, tgt_omega)
per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
elif args.rotation_loss == "hybrid":
alpha = float(max(0.0, min(1.0, args.rotation_hybrid_alpha)))
rot_loss = (
alpha * so3_geodesic_mse_loss(pred_omega, tgt_omega)
+ (1.0 - alpha) * F.mse_loss(pred_omega, tgt_omega)
)
per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
else:
rot_loss = F.mse_loss(pred_omega, tgt_omega)
mse = (
args.weight_center * (F.mse_loss(pred_center, tgt_center) / scale_c)
+ args.weight_omega * (rot_loss / scale_w)
+ args.weight_torsion
* (
torsion_wrapped_mse_loss(pred_torsion, tgt_torsion)
/ scale_t
per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
per_torsion = (wrap_angle(pred_torsion - tgt_torsion) ** 2).mean(dim=-1) / scale_t
current_weight_omega = (
args.weight_omega
if args.rotation_weight_start < 0
else (
args.rotation_weight_start
+ (
max(0.0, min(1.0, epoch / max(1, args.rotation_weight_warmup_epochs)))
* (args.weight_omega - args.rotation_weight_start)
)
)
)
mse = (
args.weight_center * per_center.mean()
+ current_weight_omega * (rot_loss / scale_w)
+ args.weight_torsion * per_torsion.mean()
)
if args.tail_risk_weight > 0:
per_sample_total = (
args.weight_center * per_center
+ current_weight_omega * per_rot
+ args.weight_torsion * per_torsion
)
mse = mse + args.tail_risk_weight * upper_quantile_mean(
per_sample_total, args.tail_risk_quantile
)
if not torch.isfinite(mse):
mse = torch.tensor(1e3, device=device, dtype=torch.float32)
# Keep graph connectivity while preventing non-finite backward.
mse = torch.nan_to_num(mse, nan=1e3, posinf=1e3, neginf=1e3)
(mse / accum).backward()
sum_mse = sum_mse + mse.detach()
@@ -1057,6 +1332,7 @@ def main() -> int:
device=device,
num_runs=args.eval_runs,
rollout_steps=args.rollout_steps,
omega_max_norm=args.omega_max_norm,
)
print(f"final_test_mean_rmsd_{args.eval_runs}: {final_mean_rmsd_100:.6f}")
@@ -1132,7 +1408,7 @@ def main() -> int:
with torch.no_grad():
v = model(batch, tb)
vc = v.center.squeeze(0)
vw = v.rot_omega.squeeze(0)
vw = clip_vector_norm_batch(v.rot_omega, args.omega_max_norm).squeeze(0)
vt = v.torsion.squeeze(0)
dt = 1.0 / max(1, n_frames - 1)
state = RigidTorsionState(