Files
ai-rfm/train.py
demian3b d894949ef9 fix: save RFM ckpt metadata; reload model for trajectory generation.
Restores gcn_residual/graph_readout parity with training in
update_best_artifacts; migration reads BEST_PRACTICE command when
ckpt keys are missing.

Made-with: Cursor
2026-04-17 13:39:46 +09:00

1578 lines
59 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Consolidated RFM toy training + visualization from sandbox_rfm.ipynb.
- DataLoader with num_workers + prefetch builds CPU batches in parallel; tensors moved to GPU for train.
- pin_memory defaults False (fewer shm/pin-thread issues); enable with --pin-memory if stable.
"""
from __future__ import annotations
import argparse
import json
import math
import os
import pickle
import sys
import time
from collections import deque
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from rdkit import Chem
from rdkit.Chem import SDWriter
from rdkit.Geometry.rdGeometry import Point3D
from torch import nn
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool
from torch_geometric.utils import add_self_loops, softmax
# ---------------------------------------------------------------------------
# Geometry / torsion (from notebook)
# ---------------------------------------------------------------------------
def generate_random_trans(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Draw one Gaussian translation sample with the given ``mean`` and ``std``.

    The standard-normal noise takes its shape, device and dtype from ``mean``;
    ``std`` is broadcast against it.
    """
    noise = torch.randn(mean.shape, device=mean.device, dtype=mean.dtype)
    return noise * std + mean
def apply_trans(x: Data, t_vec: torch.Tensor) -> Data:
    """Shift every coordinate of ``x`` by ``t_vec`` and return the same object.

    ``x.coords`` is rebound to a freshly computed tensor; the previous tensor
    is not modified in place.
    """
    shifted = x.coords + t_vec
    x.coords = shifted
    return x
def generate_random_quaternion() -> torch.Tensor:
    """Sample a random unit quaternion from three uniform variates.

    Returns:
        A length-4 tensor. The two radii ``sqrt(1 - u1)`` and ``sqrt(u1)`` each
        scale a (sin, cos) pair, so the squared components sum to one.
    """
    u1, u2, u3 = torch.rand(3).unbind(0)
    radius_a = (1 - u1).sqrt()
    radius_b = u1.sqrt()
    phase_a = 2 * torch.pi * u2
    phase_b = 2 * torch.pi * u3
    components = [
        radius_a * torch.sin(phase_a),
        radius_a * torch.cos(phase_a),
        radius_b * torch.sin(phase_b),
        radius_b * torch.cos(phase_b),
    ]
    return torch.stack(components, dim=-1)
def generate_rotation_matrix(q: torch.Tensor) -> torch.Tensor:
    """Build a 3x3 rotation matrix from a scalar-first quaternion ``q``.

    ``q[0]`` is used as the scalar part and ``q[1:4]`` as the vector part.
    The quaternion is not normalised here. The result is assembled with
    ``torch.tensor`` on ``q``'s device and dtype, so it is a new tensor built
    from the component values.
    """
    w, x, y, z = q[0], q[1], q[2], q[3]
    row_x = [
        1 - 2 * y ** 2 - 2 * z ** 2,
        2 * (x * y - w * z),
        2 * (x * z + w * y),
    ]
    row_y = [
        2 * (x * y + w * z),
        1 - 2 * x ** 2 - 2 * z ** 2,
        2 * (y * z - w * x),
    ]
    row_z = [
        2 * (x * z - w * y),
        2 * (y * z + w * x),
        1 - 2 * x ** 2 - 2 * y ** 2,
    ]
    return torch.tensor([row_x, row_y, row_z], device=q.device, dtype=q.dtype)
def apply_rot_matrix(x: Data, rot_matrix: torch.Tensor) -> Data:
    """Rotate the coordinates of ``x`` about their centroid and return the same object.

    Coordinates are handled as row vectors, so each point ``p`` becomes
    ``rot_matrix @ (p - centroid) + centroid``.
    """
    centroid = x.coords.mean(dim=0)
    rotated = (x.coords - centroid) @ rot_matrix.T
    x.coords = rotated + centroid
    return x
def _rodrigues(
    points: torch.Tensor,
    pivot: torch.Tensor,
    axis_unit: torch.Tensor,
    angle: torch.Tensor,
) -> torch.Tensor:
    """Rotate ``points`` by ``angle`` around the line through ``pivot`` along ``axis_unit``.

    Implements Rodrigues' rotation formula on the offsets ``points - pivot``.
    ``axis_unit`` is used as given; it is not normalised here.
    """
    offsets = points - pivot
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    perpendicular = torch.cross(axis_unit.expand_as(offsets), offsets, dim=-1)
    along_axis = (offsets * axis_unit).sum(-1, keepdim=True)
    return pivot + offsets * cos_a + perpendicular * sin_a + axis_unit * (along_axis * (1 - cos_a))
def apply_torsion(
    x: Any,
    angles: torch.Tensor,
    movable_mask: torch.Tensor,
    bond_endpoints: torch.Tensor,
) -> Any:
    """Apply one torsion rotation per bond to ``x.coords`` in place and return ``x``.

    Args:
        x: Object exposing a ``coords`` tensor; it is modified in place.
        angles: One rotation angle per bond (reshaped to a flat vector).
        movable_mask: Row ``b`` selects the atoms rotated by bond ``b``.
        bond_endpoints: Row ``b`` holds ``(i, j)``; the axis runs from atom ``i``
            to atom ``j`` and atom ``i`` is the pivot.

    Bonds are processed sequentially, so each rotation sees the coordinates
    already updated by the earlier ones.
    """
    coords = x.coords
    n_bonds = angles.shape[0]
    endpoints = bond_endpoints.to(device=coords.device, dtype=torch.long)
    masks = movable_mask.to(device=coords.device)
    thetas = angles.to(device=coords.device, dtype=coords.dtype).reshape(n_bonds)
    for bond in range(n_bonds):
        pivot_idx = int(endpoints[bond, 0])
        far_idx = int(endpoints[bond, 1])
        bond_vec = coords[far_idx] - coords[pivot_idx]
        axis = bond_vec / bond_vec.norm()
        moving = masks[bond].nonzero(as_tuple=True)[0]
        rotated = _rodrigues(coords[moving], coords[pivot_idx], axis, thetas[bond])
        coords.index_put_((moving,), rotated)
    return x
def generate_random_torsion_angles(
    num_bonds: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Return ``num_bonds`` angles drawn uniformly from the range 0 to 2*pi."""
    return torch.empty(num_bonds, dtype=dtype, device=device).uniform_(0, 2 * torch.pi)
def _movable_j_side(edge_index: torch.Tensor, i: int, j: int, num_atoms: int) -> torch.Tensor:
    """Mask of atoms reachable from ``j`` once the ``i``-``j`` bond is removed.

    Only the first two rows of ``edge_index`` are read. Every edge is treated
    as undirected. The returned bool tensor has length ``num_atoms`` and always
    includes ``j`` itself.
    """
    pairs = edge_index[:2]
    neighbours: list[list[int]] = [[] for _ in range(num_atoms)]
    for col in range(pairs.shape[1]):
        a = int(pairs[0, col])
        b = int(pairs[1, col])
        # Skip the bond being cut, whichever direction it is stored in.
        if {a, b} == {i, j}:
            continue
        neighbours[a].append(b)
        neighbours[b].append(a)

    reached = {j}
    frontier = deque([j])
    while frontier:
        here = frontier.popleft()
        for nxt in neighbours[here]:
            if nxt in reached:
                continue
            reached.add(nxt)
            frontier.append(nxt)

    mask = torch.zeros(num_atoms, dtype=torch.bool)
    mask[list(reached)] = True
    return mask
def bfs_torsion_masks(x: Any, root: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """List rotatable bonds in BFS order from ``root`` and build their movable-side masks.

    Rotatable bonds are the matches of a SMARTS pattern on ``x.rdmol``: single,
    non-ring bonds whose two atoms are neither terminal (degree 1) nor part of
    a triple bond. Bonds found while walking the BFS tree are oriented
    parent -> child and listed in discovery order. Any remaining rotatable
    bonds (e.g. in atoms the BFS never reached) are appended afterwards,
    oriented by BFS depth and then by atom index.

    Args:
        x: Object exposing ``rdmol``, ``coords`` (``[N, 3]``) and ``edge_index``.
        root: Atom index the BFS starts from.

    Returns:
        ``(bond_endpoints, movable_mask)``: a ``[T, 2]`` long tensor of ``(i, j)``
        pairs and a ``[T, N]`` bool tensor whose row ``k`` marks the atoms on the
        ``j`` side of bond ``k``. With no rotatable bonds the shapes are
        ``[0, 2]`` and ``[0, N]``.
    """
    rot_pat = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")
    rot_set = {frozenset((int(a), int(b))) for a, b in x.rdmol.GetSubstructMatches(rot_pat)}
    N = int(x.coords.shape[0])
    ei = x.edge_index

    # Rigid molecules have nothing to enumerate. Return well-shaped empty
    # tensors instead of letting torch.stack([]) raise further down.
    if not rot_set:
        return (
            torch.zeros((0, 2), dtype=torch.long),
            torch.zeros((0, N), dtype=torch.bool),
        )

    adj: list[list[int]] = [[] for _ in range(N)]
    for e in range(ei.shape[1]):
        u, v = int(ei[0, e]), int(ei[1, e])
        adj[u].append(v)
        adj[v].append(u)
    # Sorted neighbour lists make the BFS order deterministic.
    for nbrs in adj:
        nbrs.sort()

    depth = [-1] * N
    depth[root] = 0
    q = deque([root])
    ordered_ij: list[tuple[int, int]] = []
    tree_pairs: set[frozenset[int]] = set()
    while q:
        u = q.popleft()
        for v in adj[u]:
            if depth[v] != -1:
                continue
            depth[v] = depth[u] + 1
            q.append(v)
            fs = frozenset((u, v))
            tree_pairs.add(fs)
            if fs in rot_set:
                ordered_ij.append((u, v))

    def orient(fs: frozenset[int]) -> tuple[int, int]:
        """Order a bond's atoms shallower-first, breaking depth ties by index."""
        a, b = tuple(fs)
        da, db = depth[a], depth[b]
        if da < db or (da == db and a < b):
            return (a, b)
        return (b, a)

    # Rotatable bonds that were never BFS tree edges.
    for fs in sorted(rot_set - tree_pairs, key=orient):
        ordered_ij.append(orient(fs))

    bond_endpoints = torch.tensor(ordered_ij, dtype=torch.long)
    movable_mask = torch.stack([_movable_j_side(ei, i, j, N) for i, j in ordered_ij])
    return bond_endpoints, movable_mask
def vel_center(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Velocity that carries ``current`` to ``target`` over the remaining time ``1 - t``.

    The remaining time is clamped to at least 1e-6 so ``t == 1`` does not
    divide by zero.
    """
    remaining = (1 - t).clamp(min=1e-6)
    displacement = target - current
    return displacement / remaining
def wrap_angle(d: torch.Tensor) -> torch.Tensor:
    """Map angles (radians) to their equivalent within one turn centred on zero, via ``atan2(sin, cos)``."""
    sin_d = torch.sin(d)
    cos_d = torch.cos(d)
    return torch.atan2(sin_d, cos_d)
def vel_torsion(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Angular velocity from ``current`` to ``target`` torsions over the remaining time.

    The angle difference is wrapped before dividing, and the remaining time
    ``1 - t`` is clamped to at least 1e-6.
    """
    remaining = (1 - t).clamp(min=1e-6)
    shortest_delta = wrap_angle(target - current)
    return shortest_delta / remaining
def torsion_wrapped_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error of ``pred - target`` after wrapping the difference as an angle."""
    wrapped_error = wrap_angle(pred - target)
    return torch.mean(wrapped_error ** 2)
def upper_quantile_mean(values: torch.Tensor, quantile: float) -> torch.Tensor:
    """Mean of the largest ``1 - quantile`` fraction of ``values``.

    ``quantile`` is clamped to the range 0 to 1. At 0 this is the plain mean;
    at 1 it is the maximum. Otherwise at least one value is always kept.
    """
    frac = float(max(0.0, min(1.0, quantile)))
    if frac <= 0.0:
        return values.mean()
    if frac >= 1.0:
        return values.max()
    total = int(values.numel())
    keep = max(1, int(math.ceil(total * (1.0 - frac))))
    tail = torch.topk(values.reshape(-1), k=keep, largest=True).values
    return tail.mean()
def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor) -> torch.Tensor:
    """Mean squared rotation error between batched axis-angle vectors.

    Each row of both inputs is turned into a rotation matrix. The relative
    rotation ``R_pred^T @ R_target`` is mapped back to an axis-angle vector,
    and its squared norm is averaged over the batch. Non-finite entries are
    replaced by zero and values are clamped (inputs to +/-20, the relative
    vector to +/-10) before use.

    Raises:
        ValueError: if ``pred_omega`` is not of shape ``[B, 3]``.
    """
    if pred_omega.ndim != 2 or pred_omega.shape[-1] != 3:
        raise ValueError("pred_omega must be [B,3]")

    def _sanitise(t: torch.Tensor, bound: float) -> torch.Tensor:
        return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0).clamp(-bound, bound)

    pred_omega = _sanitise(pred_omega, 20.0)
    target_omega = _sanitise(target_omega, 20.0)
    per_sample = []
    for row in range(pred_omega.shape[0]):
        rot_pred = so3_exp(pred_omega[row])
        rot_target = so3_exp(target_omega[row])
        relative = _sanitise(so3_log(rot_pred.T @ rot_target), 10.0)
        per_sample.append(torch.sum(relative * relative))
    return torch.stack(per_sample).mean()
def clip_vector_norm_batch(x: torch.Tensor, max_norm: float) -> torch.Tensor:
    """Rescale each vector along the last axis so its norm is at most ``max_norm``.

    Vectors already within the limit are left unscaled. A non-positive
    ``max_norm`` disables clipping and returns ``x`` itself.
    """
    if max_norm <= 0:
        return x
    lengths = torch.linalg.norm(x, dim=-1, keepdim=True).clamp(min=1e-9)
    factor = (max_norm / lengths).clamp(max=1.0)
    return x * factor
def skew(w: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 skew-symmetric (cross-product) matrix of the 3-vector ``w``."""
    wx, wy, wz = w[0], w[1], w[2]
    zero = torch.zeros((), device=w.device, dtype=w.dtype)
    top = torch.stack([zero, -wz, wy])
    middle = torch.stack([wz, zero, -wx])
    bottom = torch.stack([-wy, wx, zero])
    return torch.stack([top, middle, bottom])
def so3_log(R: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Axis-angle vector (axis times angle) of the 3x3 rotation matrix ``R``.

    The angle comes from the trace, with its cosine clamped by ``eps``. The
    axis comes from the antisymmetric part of ``R`` divided by
    ``2 * sin(angle) + eps``.

    NOTE(review): as the angle approaches pi both the antisymmetric part and
    ``sin(angle)`` go to zero, so the recovered axis is presumably unreliable
    there - confirm whether callers can reach that regime.
    """
    cos_angle = ((R.trace() - 1) * 0.5).clamp(-1 + eps, 1 - eps)
    angle = torch.acos(cos_angle)
    antisym = torch.stack(
        [R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]
    )
    axis = antisym / (2 * torch.sin(angle) + eps)
    return axis * angle
def so3_exp(omega: torch.Tensor) -> torch.Tensor:
    """Rotation matrix for the axis-angle vector ``omega`` (Rodrigues' formula).

    The rotation angle is ``omega``'s norm. A tiny constant is added to the
    denominator when normalising the axis so a zero vector does not divide by
    zero.
    """
    angle = omega.norm()
    axis = omega / (angle + 1e-12)
    cross_mat = skew(axis)
    identity = torch.eye(3, device=omega.device, dtype=omega.dtype)
    first_order = torch.sin(angle) * cross_mat
    second_order = (1 - torch.cos(angle)) * (cross_mat @ cross_mat)
    return identity + first_order + second_order
def vel_rot_mat(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Angular velocity (axis-angle) taking rotation ``current`` to ``target`` over ``1 - t``.

    The relative rotation is ``current^T @ target``. The remaining time is
    clamped to at least 1e-6.
    """
    remaining = (1 - t).clamp(min=1e-6)
    relative = so3_log(current.T @ target)
    return relative / remaining
@dataclass
class RigidTorsionState:
    """Pose of a molecule as a rigid-body transform plus torsion angles."""

    # Translation component; added to coordinates by apply_trans.
    center: torch.Tensor
    # 3x3 rotation matrix; composed by matrix product in get_state / vel_rot_mat.
    rot: torch.Tensor
    # One torsion angle per rotatable bond.
    torsion: torch.Tensor
@dataclass
class RigidTorsionVelocity:
    """Velocity counterpart of ``RigidTorsionState``."""

    # Translational velocity.
    center: torch.Tensor
    # Angular velocity as an axis-angle vector (fed to so3_exp after scaling by time).
    rot_omega: torch.Tensor
    # Angular velocity of each torsion.
    torsion: torch.Tensor
def rigid_torsion_velocity(
    current: RigidTorsionState, target: RigidTorsionState, t: torch.Tensor
) -> RigidTorsionVelocity:
    """Velocities that move ``current`` to ``target`` in the time remaining after ``t``.

    Combines the per-component helpers for translation, rotation and torsion.
    ``t`` is converted to a tensor first.
    """
    t = torch.as_tensor(t)
    return RigidTorsionVelocity(
        vel_center(current.center, target.center, t),
        vel_rot_mat(current.rot, target.rot, t),
        vel_torsion(current.torsion, target.torsion, t),
    )
def get_state(
    current: RigidTorsionState, velocity: RigidTorsionVelocity, t: float
) -> RigidTorsionState:
    """Advance ``current`` along ``velocity`` for a time ``t``.

    Translation and torsion move linearly. The rotation is right-multiplied
    by the exponential of ``velocity.rot_omega * t``.
    """
    t = torch.as_tensor(t)
    moved_center = current.center + velocity.center * t
    moved_rot = current.rot @ so3_exp(velocity.rot_omega * t)
    moved_torsion = current.torsion + velocity.torsion * t
    return RigidTorsionState(moved_center, moved_rot, moved_torsion)
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def save_visualization(sample: Data, path: str) -> None:
    """Write ``sample.coords`` onto a copy of ``sample.rdmol`` and save it as a MOL file at ``path``."""
    mol = deepcopy(sample.rdmol)
    conformer = mol.GetConformer()
    for atom_idx, xyz in enumerate(sample.coords):
        px, py, pz = xyz.tolist()
        conformer.SetAtomPosition(atom_idx, Point3D(px, py, pz))
    Chem.MolToMolFile(mol, path)
def save_video_visualization(
    sample: Data,
    states: list[RigidTorsionState],
    path: str,
) -> None:
    """Write one SDF record per state, each posing the molecule from ``sample.gt_coords``.

    For every state a deep copy of ``sample`` is translated, rotated and
    torsion-adjusted (in that order) and its coordinates are written into a
    copy of ``sample.rdmol``.

    Side effect: ``sample.coords`` is reset to a clone of ``sample.gt_coords``
    before any frame is generated.
    """
    mol = deepcopy(sample.rdmol)
    sample.coords = sample.gt_coords.clone()
    with open(path, "w") as handle:
        writer = SDWriter(handle)
        for state in states:
            frame = deepcopy(sample)
            apply_trans(frame, state.center)
            apply_rot_matrix(frame, state.rot)
            apply_torsion(frame, state.torsion, frame.movable_mask, frame.bond_endpoints)
            conformer = mol.GetConformer()
            for atom_idx, xyz in enumerate(frame.coords):
                px, py, pz = xyz.tolist()
                conformer.SetAtomPosition(atom_idx, Point3D(px, py, pz))
            writer.write(mol)
        writer.close()
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
def kabsch_rmsd(pred: torch.Tensor, target: torch.Tensor) -> float:
    """RMSD between two point sets after optimal rigid alignment (Kabsch).

    Both sets are centred, the proper rotation that best maps ``pred`` onto
    ``target`` is found from the SVD of their covariance, and the RMSD of the
    aligned points is returned.

    Args:
        pred: Predicted coordinates, shape ``[N, 3]``.
        target: Reference coordinates, shape ``[N, 3]``.

    Returns:
        The aligned RMSD, or the sentinel ``1e6`` when either input contains
        non-finite values or the SVD fails.
    """
    if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
        return 1e6
    p = pred - pred.mean(dim=0, keepdim=True)
    q = target - target.mean(dim=0, keepdim=True)
    h = p.T @ q
    try:
        u, _, v_t = torch.linalg.svd(h)
    except RuntimeError:
        return 1e6
    # With h = U S V^T, the rotation acting on column vectors is r = V U^T.
    r = v_t.T @ u.T
    if torch.det(r) < 0:
        # Flip the direction of the smallest singular value to turn the
        # reflection into a proper rotation. Work on a copy of the SVD output.
        v_t = v_t.clone()
        v_t[-1, :] *= -1
        r = v_t.T @ u.T
    # Points are stored as rows, so applying r to each point means right-
    # multiplying by r^T. Using r itself would apply the inverse rotation.
    p_aligned = p @ r.T
    return float(torch.sqrt(torch.mean(torch.sum((p_aligned - q) ** 2, dim=-1))).item())
def graph_from_state_cpu(sample: Data, center: torch.Tensor, rot: torch.Tensor, torsion: torch.Tensor) -> Data:
    """Return a deep copy of ``sample`` posed at the given state, starting from ``sample.gt_coords``.

    The translation is applied first, then the rotation about the centroid,
    then the torsions. ``sample`` itself is not modified.
    """
    posed = deepcopy(sample)
    posed.coords = sample.gt_coords.clone()
    apply_trans(posed, center)
    apply_rot_matrix(posed, rot)
    apply_torsion(posed, torsion, posed.movable_mask, posed.bond_endpoints)
    return posed
def evaluate_mean_rmsd_100(
    model: nn.Module,
    sample: Data,
    target_coords: torch.Tensor,
    device: torch.device,
    num_runs: int = 100,
    rollout_steps: int = 100,
    omega_max_norm: float = 0.0,
) -> float:
    """Average final-structure RMSD over ``num_runs`` one-shot predictions from time 0.

    Each run draws a random start pose, asks the model for a velocity at
    ``t = 0``, takes a single full step with it, and scores the resulting
    coordinates against ``target_coords`` with ``kabsch_rmsd``.

    Args:
        model: Velocity model called as ``model(batch, time)``; switched to eval mode.
        sample: Template graph providing ``gt_coords`` and torsion metadata.
        target_coords: Reference coordinates for the RMSD.
        device: Device the model is evaluated on.
        num_runs: Number of random start poses.
        rollout_steps: Accepted but unused.
        omega_max_norm: Norm limit for the predicted rotation vector (<= 0 disables).

    Returns:
        The mean RMSD across runs.
    """
    del rollout_steps
    model.eval()
    n_torsions = sample.bond_endpoints.shape[0]

    # Keep the three draws interleaved per run so the random stream is consumed
    # in a fixed order.
    start_centers = []
    start_rots = []
    start_torsions = []
    for _ in range(num_runs):
        start_centers.append(generate_random_trans(torch.zeros(3), torch.tensor([3.0, 3.0, 3.0])))
        start_rots.append(generate_rotation_matrix(generate_random_quaternion()))
        start_torsions.append(generate_random_torsion_angles(n_torsions))
    center_t = torch.stack(start_centers, dim=0)  # [R,3]
    rot_t = torch.stack(start_rots, dim=0)  # [R,3,3]
    torsion_t = torch.stack(start_torsions, dim=0)  # [R,B]

    start_graphs = []
    for run in range(num_runs):
        start_graphs.append(graph_from_state_cpu(sample, center_t[run], rot_t[run], torsion_t[run]))
    batch = Batch.from_data_list(start_graphs).to(device)
    time_zero = torch.zeros((num_runs, 1), device=device, dtype=torch.float32)
    with torch.no_grad():
        velocity = model(batch, time_zero)

    step_center = velocity.center.detach().cpu()
    step_omega = clip_vector_norm_batch(velocity.rot_omega.detach().cpu(), omega_max_norm)
    step_torsion = velocity.torsion.detach().cpu()

    # One unit-time step from t=0 to t=1.
    center_t = center_t + step_center
    torsion_t = torsion_t + step_torsion
    for run in range(num_runs):
        rot_t[run] = rot_t[run] @ so3_exp(step_omega[run])

    total = 0
    for run in range(num_runs):
        final_graph = graph_from_state_cpu(sample, center_t[run], rot_t[run], torsion_t[run])
        total = total + kabsch_rmsd(final_graph.coords, target_coords)
    return float(total / num_runs)
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class TorsionSubgraphSpec(nn.Module):
    """Node ids and local connectivity of one torsion's subgraph.

    Both tensors are stored as persistent buffers, so they follow the module
    across devices and are included in its ``state_dict``:

    * ``nodes_full``: node indices in the full graph.
    * ``edge_index_local``: edges re-indexed to positions within ``nodes_full``.
    """

    def __init__(self, nodes_full: torch.Tensor, edge_index_local: torch.Tensor):
        super().__init__()
        buffers = (
            ("nodes_full", nodes_full),
            ("edge_index_local", edge_index_local),
        )
        for buffer_name, tensor in buffers:
            self.register_buffer(buffer_name, tensor, persistent=True)
def build_torsion_subgraph_specs(sample: Data) -> nn.ModuleList:
    """Build one ``TorsionSubgraphSpec`` per torsion of ``sample``.

    The node set of torsion ``k`` is its movable side (row ``k`` of
    ``sample.movable_mask``) together with both endpoints of the rotatable
    bond, so the hinge atoms are present even when one of them is not marked
    movable. Edges of ``sample.edge_index`` with both ends in that set are
    kept and re-indexed to local positions. If no edge survives, one
    self-loop per node is used instead.

    Raises:
        RuntimeError: if a torsion ends up with an empty node set.
    """

    def _to_cpu(t: torch.Tensor) -> torch.Tensor:
        return t if t.device.type == "cpu" else t.cpu()

    edges = _to_cpu(sample.edge_index)
    masks = _to_cpu(sample.movable_mask)
    endpoints = _to_cpu(sample.bond_endpoints)
    n_torsions = int(sample.bond_endpoints.shape[0])

    specs: list[TorsionSubgraphSpec] = []
    for k in range(n_torsions):
        movable_ids = torch.where(masks[k].bool())[0].long()
        hinge_i = int(endpoints[k, 0].item())
        hinge_j = int(endpoints[k, 1].item())
        hinge_ids = torch.tensor([hinge_i, hinge_j], dtype=torch.long)
        nodes_full = torch.unique(torch.cat([movable_ids, hinge_ids]), sorted=True)
        if nodes_full.numel() == 0:
            raise RuntimeError(f"torsion {k}: empty movable side; cannot build subgraph")

        # nodes_full is sorted and unique, so list position == local node id.
        local_of = {int(full_id): pos for pos, full_id in enumerate(nodes_full.tolist())}
        src_local: list[int] = []
        dst_local: list[int] = []
        for col in range(edges.shape[1]):
            src = int(edges[0, col])
            dst = int(edges[1, col])
            if src in local_of and dst in local_of:
                src_local.append(local_of[src])
                dst_local.append(local_of[dst])

        if src_local:
            edge_index_local = torch.tensor([src_local, dst_local], dtype=torch.long)
        else:
            # No induced edges: fall back to one self-loop per node.
            loop_ids = torch.arange(nodes_full.numel(), dtype=torch.long)
            edge_index_local = torch.stack(
                [loop_ids, torch.arange(nodes_full.numel(), dtype=torch.long)], dim=0
            )
        specs.append(TorsionSubgraphSpec(nodes_full, edge_index_local))
    return nn.ModuleList(specs)
class AttentionGraphReadout(nn.Module):
    """Graph readout that sums node features weighted by a learned per-graph softmax.

    A small MLP scores every node. Scores are normalised with a softmax over
    the nodes of each graph, and the weighted node features are summed per
    graph.
    """

    def __init__(self, node_dim: int, gate_hidden: int = 64):
        super().__init__()
        scorer_layers = [
            nn.Linear(node_dim, gate_hidden),
            nn.SiLU(),
            nn.Linear(gate_hidden, 1),
        ]
        self.gate = nn.Sequential(*scorer_layers)

    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node features ``h`` into one vector per graph using ``batch`` as the graph index."""
        node_scores = self.gate(h).squeeze(-1)
        weights = softmax(node_scores, batch)
        weighted = h * weights.unsqueeze(-1)
        return global_add_pool(weighted, batch)
class RFMModel(nn.Module):
    """Model for velocity prediction. Supports GCN and fixed-order MLP encoders.

    The encoder ("trunk") turns atom coordinates into one feature vector per
    graph:

    * ``model_type == "gcn"``: a stack of ``GCNConv`` layers on ``x.coords``,
      pooled per graph by mean or by ``AttentionGraphReadout``.
    * ``model_type == "mlp"``: coordinates are flattened to
      ``num_nodes * 3`` features per graph and passed through Linear + SiLU layers.

    Three time features ``[t, sin(2*pi*t), cos(2*pi*t)]`` are appended to the
    trunk output and optionally layer-normalised. Separate heads then predict
    the translation velocity (3), the rotation velocity (3) and the torsion
    velocities. The torsion head is either one global MLP
    (``torsion_head_mode == "global"``) or a per-torsion subgraph encoder
    (``"bond_pair"``, GCN trunk only).
    """
    def __init__(
        self,
        model_type: str,
        num_nodes: int,
        num_torsions: int,
        hidden: int = 128,
        num_layers: int = 4,
        gcn_residual: bool = False,
        channel_layernorm: bool = False,
        head_mlp_layers: int = 1,
        graph_readout: str = "mean",
        torsion_head_mode: str = "global",
        torsion_subgraphs: nn.ModuleList | None = None,
        torsion_subgraph_output_clip: float = 0.0,
        torsion_sub_gcn_layers: int | None = None,
    ):
        """Build the encoder and the three output heads.

        Args:
            model_type: ``"gcn"`` or ``"mlp"``.
            num_nodes: Atoms per graph (sizes the MLP encoder input).
            num_torsions: Number of torsion outputs.
            hidden: Hidden width of the encoder.
            num_layers: Number of encoder layers.
            gcn_residual: Add a skip connection around GCN layers whose input
                and output shapes match.
            channel_layernorm: Apply LayerNorm to the shared features and
                inside multi-layer heads.
            head_mlp_layers: Linear layers per head (values below 1 become 1).
            graph_readout: ``"mean"`` or ``"attention"`` (GCN only).
            torsion_head_mode: ``"global"`` or ``"bond_pair"``.
            torsion_subgraphs: One spec per torsion; required for ``"bond_pair"``.
            torsion_subgraph_output_clip: If > 0, clamp bond_pair torsion
                outputs to +/- this value (negative values become 0).
            torsion_sub_gcn_layers: Layer count of the subgraph encoder;
                defaults to ``num_layers`` and is limited to between 1 and ``num_layers``.

        Raises:
            ValueError: on an unsupported ``model_type``, ``graph_readout`` or
                ``torsion_head_mode``, or an inconsistent bond_pair setup.
        """
        super().__init__()
        self.model_type = model_type
        self.num_nodes = num_nodes
        self.num_torsions = num_torsions
        self.hidden = hidden
        self.num_layers = num_layers
        self.gcn_residual = gcn_residual
        self.channel_layernorm = channel_layernorm
        self.act = nn.SiLU()
        self.head_mlp_layers = max(1, int(head_mlp_layers))
        self.graph_readout = graph_readout
        self.torsion_head_mode = torsion_head_mode
        self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
        self.sub_gcn_skip: nn.ModuleList | None = None
        if model_type == "gcn":
            # First layer consumes raw 3-D coordinates; the rest keep width `hidden`.
            self.convs = nn.ModuleList()
            self.convs.append(GCNConv(3, hidden))
            for _ in range(num_layers - 1):
                self.convs.append(GCNConv(hidden, hidden))
            # One skip module per conv layer (all Identity).
            self.gcn_skip = nn.ModuleList([nn.Identity()])
            for _ in range(num_layers - 1):
                self.gcn_skip.append(nn.Identity())
            trunk_dim = hidden
            if graph_readout == "mean":
                # None means forward() falls back to global_mean_pool.
                self._pool = None
            elif graph_readout == "attention":
                self._pool = AttentionGraphReadout(trunk_dim, gate_hidden=max(32, min(128, hidden)))
            else:
                raise ValueError(f"Unsupported graph_readout: {graph_readout}")
        elif model_type == "mlp":
            self._pool = None
            mlp_layers = []
            # Flattened coordinates of a fixed-size, fixed-order graph.
            in_dim = num_nodes * 3
            for _ in range(num_layers):
                mlp_layers.append(nn.Linear(in_dim, hidden))
                mlp_layers.append(nn.SiLU())
                in_dim = hidden
            self.encoder = nn.Sequential(*mlp_layers)
            trunk_dim = hidden
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")
        # t, sin(2*pi*t), cos(2*pi*t) - see forward().
        self.time_feat_dim = 3
        d = trunk_dim + self.time_feat_dim
        self.shared_norm = nn.LayerNorm(d) if channel_layernorm else nn.Identity()
        def make_head(out_dim: int) -> nn.Sequential:
            """MLP head: (head_mlp_layers - 1) hidden blocks of width d, then a Linear to out_dim."""
            layers: list[nn.Module] = []
            in_dim = d
            for _ in range(self.head_mlp_layers - 1):
                layers.append(nn.Linear(in_dim, d))
                if channel_layernorm:
                    layers.append(nn.LayerNorm(d))
                layers.append(nn.SiLU())
                in_dim = d
            layers.append(nn.Linear(in_dim, out_dim))
            return nn.Sequential(*layers)
        self.trans_head = make_head(3)
        self.rot_head = make_head(3)
        if torsion_head_mode == "global":
            # Single head emitting all torsion velocities; bond_pair members stay unset.
            self.torsion_head = make_head(num_torsions)
            self.torsion_pair_mlp = None
            self.torsion_pair_in_norm = None
            self.torsion_subgraphs = None
            self.sub_convs = None
            self.sub_pool_norm = None
            self.sub_gcn_skip = None
        elif torsion_head_mode == "bond_pair":
            if model_type != "gcn":
                raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
            if torsion_subgraphs is None or len(torsion_subgraphs) != num_torsions:
                raise ValueError("bond_pair requires torsion_subgraphs built from sample.movable_mask")
            # Dedicated subgraph encoder (weights not shared with full-graph convs).
            # First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
            # throwing away trunk geometry context when re-encoding the movable fragment only.
            ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers)
            # Keep the subgraph depth between 1 and the trunk depth.
            ns = max(1, min(ns, num_layers))
            self.sub_num_layers = ns
            self.sub_convs = nn.ModuleList()
            self.sub_convs.append(GCNConv(3 + hidden, hidden))
            for _ in range(ns - 1):
                self.sub_convs.append(GCNConv(hidden, hidden))
            self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)])
            # Global pooled context + mean pool over subgraph GCN.
            pair_in = d + hidden
            ph = max(64, min(256, hidden * 2))
            self.torsion_head = None
            self.torsion_subgraphs = torsion_subgraphs
            self.sub_pool_norm = nn.LayerNorm(hidden)
            self.torsion_pair_in_norm = nn.LayerNorm(pair_in)
            # Shared across torsions: maps one (global context, subgraph summary) pair to a scalar.
            self.torsion_pair_mlp = nn.Sequential(
                nn.Linear(pair_in, ph),
                nn.SiLU(),
                nn.Linear(ph, ph),
                nn.SiLU(),
                nn.Linear(ph, 1),
            )
            def _small_linear_init(m: nn.Module) -> None:
                """Re-initialise Linear layers with a reduced Xavier gain and zero bias."""
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight, gain=0.25)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
            self.torsion_pair_mlp.apply(_small_linear_init)
        else:
            raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
    def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor, h_full: torch.Tensor) -> torch.Tensor:
        """Batched subgraph GCN; node input is [coords | h_full] gathered onto each torsion subgraph.

        Args:
            x: Batched graph; must carry ``ptr`` (per-graph node offsets) and ``coords``.
            pooled: Shared per-graph features, one row per graph.
            h_full: Node embeddings of the full batched graph from the trunk.

        Returns:
            One torsion velocity per (graph, torsion), shape ``[B, T]``. Non-finite
            values are replaced by zero and the output is optionally clamped.

        Raises:
            RuntimeError: if ``x`` has no ``ptr``.
        """
        if self.num_torsions == 0:
            return pooled.new_zeros((pooled.shape[0], 0))
        assert self.torsion_subgraphs is not None
        assert self.sub_convs is not None
        assert self.torsion_pair_mlp is not None and self.torsion_pair_in_norm is not None
        if not hasattr(x, "ptr") or x.ptr is None:
            raise RuntimeError("bond_pair torsion head requires Batch.ptr (use Batch.from_data_list)")
        B = int(pooled.shape[0])
        T = self.num_torsions
        ptr = x.ptr.to(device=pooled.device)
        device = pooled.device
        dtype = pooled.dtype
        graphs: list[Data] = []
        # Build T * B small graphs, torsion-major: all batch items of torsion 0, then torsion 1, ...
        for k in range(T):
            spec = self.torsion_subgraphs[k]
            nodes_f = spec.nodes_full.to(device=device, dtype=torch.long)
            ei_loc = spec.edge_index_local.to(device=device)
            for b in range(B):
                # Offset of graph b's nodes inside the batched node tensors.
                base = int(ptr[b].item())
                idx = base + nodes_f
                coords_sub = x.coords[idx].to(dtype=dtype)
                h_sub = h_full[idx].to(dtype=dtype)
                x_sub = torch.cat([coords_sub, h_sub], dim=-1)
                graphs.append(Data(x=x_sub, edge_index=ei_loc))
        batch_sub = Batch.from_data_list(graphs).to(device)
        h = batch_sub.x
        edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
        batch_vec = batch_sub.batch
        assert self.sub_gcn_skip is not None
        for idx, conv in enumerate(self.sub_convs):
            h_new = self.act(conv(h, edge_index))
            # Residual only where the layer keeps the feature shape unchanged.
            if self.gcn_residual and h_new.shape == h.shape:
                h = h_new + self.sub_gcn_skip[idx](h)
            else:
                h = h_new
        sub_pooled = global_mean_pool(h, batch_vec)
        assert self.sub_pool_norm is not None
        sub_pooled = self.sub_pool_norm(sub_pooled)
        # Rows are ordered (torsion, batch); reshape to [T, B, hidden] and swap to [B, T, hidden].
        sub_bt = sub_pooled.view(T, B, self.hidden).transpose(0, 1)
        # Repeat each graph's shared features for every torsion.
        pooled_exp = pooled.unsqueeze(1).expand(B, T, pooled.shape[-1])
        feat = torch.cat([pooled_exp, sub_bt], dim=-1).reshape(B * T, -1)
        feat = self.torsion_pair_in_norm(feat)
        out = self.torsion_pair_mlp(feat).squeeze(-1).view(B, T)
        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
        c = self.torsion_subgraph_output_clip
        if c > 0:
            out = out.clamp(min=-c, max=c)
        return out
    def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
        """Predict translation, rotation and torsion velocities for a batch.

        Args:
            x: Batched graph with ``coords``; the GCN path also reads
                ``edge_index`` and ``batch``.
            time: Time value per graph; it is concatenated along the last axis
                with the trunk output. The MLP path uses ``time.shape[0]`` as
                the batch size.

        Returns:
            A ``RigidTorsionVelocity`` holding the three head outputs.
        """
        if self.model_type == "gcn":
            h = x.coords
            edge_index = x.edge_index
            batch = x.batch
            for idx, conv in enumerate(self.convs):
                h_new = self.act(conv(h, edge_index))
                # Residual only where the layer keeps the feature shape unchanged.
                if self.gcn_residual and h_new.shape == h.shape:
                    h = h_new + self.gcn_skip[idx](h)
                else:
                    h = h_new
            if self._pool is None:
                trunk = global_mean_pool(h, batch)
            else:
                trunk = self._pool(h, batch)
        else:
            bsz = time.shape[0]
            # Relies on every graph having exactly num_nodes atoms in a fixed order.
            coords = x.coords.reshape(bsz, self.num_nodes * 3)
            trunk = self.encoder(coords)
        t = time.to(dtype=trunk.dtype, device=trunk.device)
        two_pi_t = 2.0 * torch.pi * t
        time_feats = torch.cat([t, torch.sin(two_pi_t), torch.cos(two_pi_t)], dim=-1)
        pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
        trans = self.trans_head(pooled)
        rot_omega = self.rot_head(pooled)
        if self.torsion_head_mode == "bond_pair":
            # `h` (node embeddings) only exists on the GCN path.
            assert self.model_type == "gcn"
            torsion = self._torsion_from_subgraph_specs(x, pooled, h)
        else:
            assert self.torsion_head is not None
            torsion = self.torsion_head(pooled)
        return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_sample_from_pickle(path: str, pos: int) -> Data:
    """Load one pickled record starting at byte offset ``pos`` of ``path`` and repackage it as ``Data``.

    Only the first two rows of the record's ``edge_index`` are kept. The
    ground-truth and torsion attributes are copied across when the record
    has them.
    """
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # use this on shards from a trusted source.
    with open(path, "rb") as handle:
        handle.seek(pos)
        raw = pickle.load(handle)
    raw.edge_index = raw.edge_index[:2, :]
    sample = Data(coords=raw.coords, edge_index=raw.edge_index)
    sample.num_nodes = int(raw.coords.shape[0])
    sample.rdmol = raw.rdmol
    optional_fields = (
        "gt_coords",
        "gt_trans",
        "gt_quaternion",
        "gt_torsion",
        "bond_endpoints",
        "movable_mask",
    )
    for field_name in optional_fields:
        if not hasattr(raw, field_name):
            continue
        setattr(sample, field_name, getattr(raw, field_name))
    return sample
def load_sample_from_sdf(sdf_path: str) -> Data:
    """Read a molecule file with RDKit and convert it to a ``Data`` graph.

    Hydrogens are kept. Every bond contributes two directed edges (both
    directions).

    Args:
        sdf_path: Path handed to ``Chem.MolFromMolFile``.

    Returns:
        ``Data`` with ``coords`` (``[N, 3]`` float32), ``edge_index``
        (``[2, E]`` long), ``num_nodes`` and ``rdmol`` (a copy of the parsed molecule).

    Raises:
        RuntimeError: if RDKit cannot parse the file.
    """
    mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
    if mol is None:
        raise RuntimeError(f"RDKit could not read {sdf_path}")
    conf = mol.GetConformer()
    n = mol.GetNumAtoms()
    # Fetch each position once instead of once per component.
    positions = [conf.GetAtomPosition(i) for i in range(n)]
    # reshape keeps the [N, 3] layout even for an atom-less molecule.
    coords = torch.tensor(
        [[p.x, p.y, p.z] for p in positions],
        dtype=torch.float32,
    ).reshape(-1, 3)
    edges: list[tuple[int, int]] = []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        edges.append((i, j))
        edges.append((j, i))
    # Without the reshape, a molecule with no bonds would give a 1-D tensor
    # and downstream `edge_index[0, e]` / `.shape[1]` accesses would fail.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).T.contiguous()
    data = Data()
    data.coords = coords
    data.edge_index = edge_index
    data.num_nodes = int(coords.shape[0])
    data.rdmol = Chem.Mol(mol)
    return data
def prepare_sample(sample: Data) -> Data:
    """Attach ground-truth and torsion metadata to ``sample`` in place and return it.

    The current coordinates become the ground truth (``gt_coords``),
    ``gt_trans`` is their centroid, and ``gt_torsion`` is all zeros.
    ``bond_endpoints`` and ``movable_mask`` come from ``bfs_torsion_masks``.

    NOTE(review): ``gt_quaternion`` is stored as all zeros, which is not a
    unit quaternion; ``generate_rotation_matrix`` does map it to the identity.
    Confirm no consumer normalises it.
    """
    coords = sample.coords
    sample.num_nodes = int(coords.shape[0])
    sample.gt_coords = coords.clone()
    sample.gt_trans = torch.mean(coords, dim=0, dtype=torch.float32)
    sample.gt_quaternion = torch.tensor((0, 0, 0, 0), dtype=torch.float32)
    endpoints, masks = bfs_torsion_masks(sample)
    sample.bond_endpoints = endpoints
    sample.movable_mask = masks
    n_torsions = sample.bond_endpoints.shape[0]
    sample.gt_torsion = torch.zeros(n_torsions, dtype=torch.float32)
    return sample
# ---------------------------------------------------------------------------
# CPU batch building (DataLoader workers or main process)
# ---------------------------------------------------------------------------
def _gt_state_cpu(gs: RigidTorsionState) -> RigidTorsionState:
    """Return a new state whose three tensors live on the CPU."""
    center_cpu = gs.center.cpu()
    rot_cpu = gs.rot.cpu()
    torsion_cpu = gs.torsion.cpu()
    return RigidTorsionState(center=center_cpu, rot=rot_cpu, torsion=torsion_cpu)
def make_rfm_example(
    sample: Data,
    gt_state_cpu: RigidTorsionState,
    cpu: torch.device,
    time_power: float = 1.0,
) -> dict[str, Any]:
    """Build one random training example: a posed graph plus its velocity targets.

    A random start state is drawn, the velocity towards ``gt_state_cpu`` is
    computed at time 0, and the start state is advanced along it for a random
    time ``u ** time_power`` with ``u`` uniform on the unit interval. A deep
    copy of ``sample`` is posed at that intermediate state.

    Returns:
        A dict with the posed graph under ``"data"``, the three velocity
        targets under ``"v_center"``, ``"v_omega"`` and ``"v_torsion"``, and the
        sampled time under ``"time"``. All tensors are on the CPU.
    """
    # The draw order (translation, rotation, torsion, time) fixes how the
    # random stream is consumed; keep it as is.
    zero_mean = torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu)
    spread = torch.tensor([3, 3, 3], dtype=torch.float32, device=cpu)
    start_center = generate_random_trans(zero_mean, spread)
    start_rot = generate_rotation_matrix(generate_random_quaternion())
    start_torsion = generate_random_torsion_angles(sample.bond_endpoints.shape[0], device=cpu)
    start = RigidTorsionState(center=start_center, rot=start_rot, torsion=start_torsion)

    velocity = rigid_torsion_velocity(start, gt_state_cpu, torch.tensor(0.0, device=cpu))
    t_scalar = torch.rand((), device=cpu) ** time_power
    posed_state = get_state(start, velocity, float(t_scalar.item()))

    graph = deepcopy(sample)
    graph.coords = graph.coords.to(device=cpu, dtype=torch.float32)
    graph.edge_index = graph.edge_index.to(device=cpu)
    for optional_name in ("movable_mask", "bond_endpoints"):
        if hasattr(graph, optional_name):
            setattr(graph, optional_name, getattr(graph, optional_name).to(device=cpu))
    apply_trans(graph, posed_state.center)
    apply_rot_matrix(graph, posed_state.rot)
    apply_torsion(graph, posed_state.torsion, graph.movable_mask, graph.bond_endpoints)

    example: dict[str, Any] = {"data": graph}
    example["v_center"] = velocity.center.detach().cpu()
    example["v_omega"] = velocity.rot_omega.detach().cpu()
    example["v_torsion"] = velocity.torsion.detach().cpu()
    example["time"] = t_scalar.detach().cpu()
    return example
def rfm_collate(batch: list[dict[str, Any]]) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate example dicts into a batched graph plus stacked time and target tensors.

    Returns:
        ``(graphs, time, v_center, v_omega, v_torsion)``. Each tensor is
        stacked along a new leading axis; ``time`` also gets a trailing unit axis.
    """

    def _stacked(key: str) -> torch.Tensor:
        return torch.stack([example[key] for example in batch], dim=0)

    batched = Batch.from_data_list([example["data"] for example in batch])
    time_b = _stacked("time").unsqueeze(-1)
    target_center = _stacked("v_center")
    target_omega = _stacked("v_omega")
    target_torsion = _stacked("v_torsion")
    return batched, time_b, target_center, target_omega, target_torsion
class SameGraphRandomRFMDataset(Dataset):
    """Dataset that yields a fresh random RFM example of one fixed molecule on every access.

    The index passed to ``__getitem__`` is ignored; ``length`` only sets how
    many items the dataset reports.
    """

    def __init__(
        self,
        sample: Data,
        gt_state_cpu: RigidTorsionState,
        length: int,
        time_power: float = 1.0,
    ):
        self.sample = sample
        # Store a CPU copy of the ground-truth state.
        self.gt_state = _gt_state_cpu(gt_state_cpu)
        self.length = length
        self.time_power = time_power

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        del idx
        cpu = torch.device("cpu")
        return make_rfm_example(self.sample, self.gt_state, cpu, self.time_power)
def build_one_batch(
    sample: Data,
    gt_state: RigidTorsionState,
    batch_size: int,
    cpu: torch.device,
    time_power: float = 1.0,
) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build one collated batch of ``batch_size`` random examples in the calling process.

    Uses the same example builder and collate function as the DataLoader path.
    """
    examples = []
    for _ in range(batch_size):
        examples.append(make_rfm_example(sample, gt_state, cpu, time_power=time_power))
    return rfm_collate(examples)
def infinite_batches(loader: DataLoader):
    """Yield batches from ``loader`` forever, restarting it whenever it is exhausted.

    Args:
        loader: Any re-iterable source of batches.

    Yields:
        Batches in loader order, pass after pass.

    Raises:
        ValueError: if a full pass over ``loader`` produces no batch. Without
            this check an empty loader would spin forever without yielding.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("infinite_batches: loader produced no batches")
def main() -> int:
    """Train the RFM velocity model on one ligand, evaluate it, and write artifacts.

    Steps, in order: parse CLI flags; load a sample from a pickle shard or an
    SDF file; build the model and optimizer; run the training loop (gradient
    accumulation, optional EMA shadow, optional train-loss early stopping);
    reload and save the best-by-train-loss weights; run the rollout RMSD
    evaluation; write ``reports/latest_eval.json``; and dump SDF visualizations
    under ``--out-dir``.

    Returns:
        Process exit status: 0 on success, 1 when neither the pickle nor the
        SDF input exists.
    """
    here = os.path.dirname(os.path.abspath(__file__))

    # ------------------------------------------------------------------
    # CLI
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--pickle-path", default="", help="Optional pickle shard path")
    parser.add_argument("--pickle-pos", type=int, default=61908729)
    parser.add_argument("--sdf", default=os.path.join(here, "sample.sdf"), help="Fallback ligand SDF")
    parser.add_argument("--out-dir", default=os.path.join(here, "rfm_out"))
    parser.add_argument("--epochs", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
    parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
    parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
    parser.add_argument(
        "--tail-risk-weight",
        type=float,
        default=0.0,
        help="Extra weight on upper-quantile per-sample loss (0 disables).",
    )
    parser.add_argument(
        "--tail-risk-quantile",
        type=float,
        default=0.9,
        help="Upper quantile threshold for tail-risk penalty.",
    )
    parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
    parser.add_argument(
        "--model-type",
        type=str,
        default="mlp",
        choices=["mlp", "gcn"],
        help="Backbone encoder for velocity prediction.",
    )
    parser.add_argument(
        "--gcn-layers",
        type=int,
        default=4,
        help="Number of GCNConv layers (more => longer GPU kernels)",
    )
    parser.add_argument(
        "--accum",
        type=int,
        default=1,
        help="Gradient accumulation steps per optimizer step (keeps GPU busier vs CPU batch build)",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="torch.compile(model) when supported (PyTorch 2+)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=-1,
        help="DataLoader workers; -1 = min(8, CPU-1), 0 = main process only",
    )
    parser.add_argument(
        "--prefetch-factor",
        type=int,
        default=4,
        help="Per-worker prefetch queue depth (ignored if num-workers=0)",
    )
    parser.add_argument(
        "--pin-memory",
        action="store_true",
        help="DataLoader pin_memory (faster H2D; can hit shm/pin issues in some setups)",
    )
    parser.add_argument(
        "--sync-batch",
        action="store_true",
        help="Build batches in main thread (no DataLoader workers)",
    )
    parser.add_argument(
        "--dataset-len",
        type=int,
        default=1_000_000,
        help="Synthetic Dataset length (only affects shuffle / epoch size)",
    )
    parser.add_argument(
        "--eval-runs",
        type=int,
        default=100,
        help="Number of rollout tests for final RMSD report",
    )
    parser.add_argument(
        "--rollout-steps",
        type=int,
        default=100,
        help="Integration steps per rollout during final test",
    )
    parser.add_argument(
        "--time-power",
        type=float,
        default=1.0,
        help="Sample t as U(0,1)^time_power; >1 biases training toward t~0 states.",
    )
    parser.add_argument(
        "--loss-domain",
        type=str,
        default="displacement",
        choices=["velocity", "displacement"],
        help="Train on raw velocity or displacement-to-target (v * (1 - t)).",
    )
    parser.add_argument(
        "--rotation-loss",
        type=str,
        default="mse",
        choices=["mse", "geodesic", "hybrid"],
        help="Rotation-channel loss type.",
    )
    parser.add_argument(
        "--rotation-hybrid-alpha",
        type=float,
        default=0.5,
        help="For rotation-loss=hybrid: weight on geodesic term (0..1).",
    )
    parser.add_argument(
        "--rotation-weight-start",
        type=float,
        default=-1.0,
        help="If >=0, linearly anneal effective rotation weight from start to --weight-omega.",
    )
    parser.add_argument(
        "--rotation-weight-warmup-epochs",
        type=int,
        default=0,
        help="Warmup epochs for rotation weight schedule.",
    )
    parser.add_argument(
        "--gcn-residual",
        action="store_true",
        help="Enable residual skip in GCN trunk.",
    )
    parser.add_argument(
        "--channel-layernorm",
        action="store_true",
        help="Apply layernorm to pooled trunk and head hidden layers.",
    )
    parser.add_argument(
        "--head-mlp-layers",
        type=int,
        default=1,
        help="Per-channel head depth (>=1).",
    )
    parser.add_argument(
        "--graph-readout",
        type=str,
        default="mean",
        choices=["mean", "attention"],
        help="GCN graph readout: mean pool vs learned attention pool (ignored for model-type=mlp).",
    )
    parser.add_argument(
        "--torsion-head",
        type=str,
        default="global",
        choices=["global", "bond_pair"],
        help=(
            "global: torsion vector from pooled trunk+time. "
            "bond_pair (GCN only): induced subgraph = movable ∪ bond (i,j); subgraph node input = concat(coords, "
            "full-graph trunk node h); dedicated sub_convs; batched B×T forward; MLP on [global pooled || subgraph pool]."
        ),
    )
    parser.add_argument(
        "--subgraph-lr-scale",
        type=float,
        default=0.3,
        help="bond_pair: LR multiplier for subgraph encoder + torsion head (main trunk uses --lr).",
    )
    parser.add_argument(
        "--subgraph-torsion-clip",
        type=float,
        default=6.0,
        help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
    )
    parser.add_argument(
        "--torsion-sub-gcn-layers",
        type=int,
        default=-1,
        help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).",
    )
    parser.add_argument(
        "--early-stop-patience",
        type=int,
        default=0,
        help="Train-loss early stop patience in checks (0 disables).",
    )
    parser.add_argument(
        "--early-stop-min-delta",
        type=float,
        default=0.0,
        help="Minimum train-loss improvement required to reset patience.",
    )
    parser.add_argument(
        "--early-stop-check-every",
        type=int,
        default=1,
        help="Evaluate early-stop condition every N epochs.",
    )
    parser.add_argument(
        "--early-stop-warmup",
        type=int,
        default=0,
        help="Skip early-stop checks for first N epochs.",
    )
    parser.add_argument(
        "--omega-max-norm",
        type=float,
        default=0.0,
        help="If >0, clip predicted rotation omega vector norm for stability.",
    )
    parser.add_argument(
        "--ema-decay",
        type=float,
        default=0.0,
        help=(
            "If in (0,1), maintain an EMA shadow of model weights after each optimizer step; "
            "best_train checkpoint and final eval load the EMA weights (training uses online weights)."
        ),
    )
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Device + sample loading
    # ------------------------------------------------------------------
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")

    # The pickle shard takes precedence; the SDF is the fallback input.
    if args.pickle_path and os.path.isfile(args.pickle_path):
        print("Loading sample from pickle:", args.pickle_path, "pos", args.pickle_pos)
        sample = load_sample_from_pickle(args.pickle_path, args.pickle_pos)
    elif os.path.isfile(args.sdf):
        print("Loading sample from SDF:", args.sdf)
        sample = load_sample_from_sdf(args.sdf)
    else:
        print("No pickle or SDF found. Provide --pickle-path or --sdf", file=sys.stderr)
        return 1
    sample = prepare_sample(sample)

    # bond_pair torsion head is only available with the GCN backbone.
    torsion_head_mode = args.torsion_head
    if torsion_head_mode == "bond_pair" and args.model_type != "gcn":
        print(
            "Warning: --torsion-head bond_pair requires --model-type gcn; using global torsion head.",
            file=sys.stderr,
        )
        torsion_head_mode = "global"

    os.makedirs(args.out_dir, exist_ok=True)
    save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))

    # gt_state: zero translation, rotation built from an all-zero quaternion
    # tensor, zero torsions. It is the state handed to the batch builders.
    gt_state = RigidTorsionState(
        center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
        rot=generate_rotation_matrix(torch.tensor([0, 0, 0, 0], dtype=torch.float32)),
        torsion=torch.zeros(sample.bond_endpoints.shape[0], dtype=torch.float32),
    )
    # target_state: built from the sample's stored quaternion / torsions; only
    # used for the ground-truth velocity trajectory visualization below.
    target_state = RigidTorsionState(
        center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
        rot=generate_rotation_matrix(sample.gt_quaternion),
        torsion=sample.gt_torsion,
    )

    # ------------------------------------------------------------------
    # Model + optimizer
    # ------------------------------------------------------------------
    torsion_subgraphs: nn.ModuleList | None = None
    if torsion_head_mode == "bond_pair":
        torsion_subgraphs = build_torsion_subgraph_specs(sample)
    # A negative CLI value means "not specified"; pass None to the model then.
    torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
    model = RFMModel(
        model_type=args.model_type,
        num_nodes=int(sample.num_nodes),
        num_torsions=int(sample.bond_endpoints.shape[0]),
        hidden=args.hidden,
        num_layers=args.gcn_layers,
        gcn_residual=args.gcn_residual,
        channel_layernorm=args.channel_layernorm,
        head_mlp_layers=args.head_mlp_layers,
        graph_readout=args.graph_readout,
        torsion_head_mode=torsion_head_mode,
        torsion_subgraphs=torsion_subgraphs,
        torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
        torsion_sub_gcn_layers=torsion_sub_gcn_layers,
    ).to(device)

    # Record subgraph-branch parameters by identity BEFORE torch.compile wraps
    # the module, because the wrapper changes the parameter name prefixes.
    subgraph_param_ids: set[int] = set()
    if torsion_head_mode == "bond_pair":
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")):
                subgraph_param_ids.add(id(p))

    if args.compile and hasattr(torch, "compile"):
        model = torch.compile(model)  # type: ignore[assignment]

    if torsion_head_mode == "bond_pair":
        # Two parameter groups: trunk at --lr, subgraph branch at a scaled LR.
        base_params: list[nn.Parameter] = []
        sub_params: list[nn.Parameter] = []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            if id(p) in subgraph_param_ids:
                sub_params.append(p)
            else:
                base_params.append(p)
        optimizer = torch.optim.Adam(
            [
                {"params": base_params, "lr": args.lr},
                {"params": sub_params, "lr": args.lr * float(args.subgraph_lr_scale)},
            ],
        )
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def model_core(m: nn.Module) -> nn.Module:
        """Return the module behind a torch.compile wrapper, or ``m`` itself."""
        return m._orig_mod if hasattr(m, "_orig_mod") else m

    # EMA shadow is only kept when the decay lies strictly inside (0, 1).
    ema_decay = float(args.ema_decay)
    ema_state: dict[str, torch.Tensor] | None = None
    if 0.0 < ema_decay < 1.0:
        core0 = model_core(model)
        ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()}

    @torch.no_grad()
    def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None:
        """In-place EMA step on ``shadow``; non-float entries are copied verbatim."""
        src = model_core(m).state_dict()
        for k, dst in shadow.items():
            t = src[k]
            if torch.is_floating_point(t):
                dst.mul_(decay).add_(t, alpha=1.0 - decay)
            else:
                dst.copy_(t)

    # ------------------------------------------------------------------
    # Run configuration summary
    # ------------------------------------------------------------------
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(
        f"device={device} model={args.model_type} hidden={args.hidden} "
        f"layers={args.gcn_layers} params={n_params:,}"
    )
    print(f"time sampling power={args.time_power}")
    print(
        "loss weights: "
        f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
        f"grad_clip={args.grad_clip}"
    )
    print(f"tail_risk: weight={args.tail_risk_weight} quantile={args.tail_risk_quantile}")
    print("loss domain=%s torsion_wrapped_loss=always_on" % args.loss_domain)
    print(
        f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
        f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers} "
        f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
    )
    if torsion_head_mode == "bond_pair":
        sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
        print(
            f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
            f"subgraph_torsion_clip={args.subgraph_torsion_clip} "
            f"torsion_sub_gcn_layers={sub_l}"
        )
    if args.rotation_loss == "hybrid":
        print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
    if args.rotation_weight_start >= 0:
        print(
            "rotation_weight_schedule: "
            f"start={args.rotation_weight_start} end={args.weight_omega} "
            f"warmup_epochs={args.rotation_weight_warmup_epochs}"
        )
    print(
        "early_stop: "
        f"patience={args.early_stop_patience} min_delta={args.early_stop_min_delta} "
        f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
    )
    print(f"omega_max_norm={args.omega_max_norm}")
    if ema_state is not None:
        print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)")

    # ------------------------------------------------------------------
    # Batch source: synchronous builder or DataLoader with workers
    # ------------------------------------------------------------------
    nw = args.num_workers
    if nw < 0:
        nw = min(8, max(0, (os.cpu_count() or 4) - 1))
    pf = args.prefetch_factor if nw > 0 else None
    persistent = nw > 0

    if args.sync_batch:

        def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return build_one_batch(
                sample,
                gt_state,
                args.batch_size,
                cpu,
                args.time_power,
            )

        print("batching: sync (main thread) workers=0")
    else:
        dataset = SameGraphRandomRFMDataset(
            sample,
            gt_state_cpu=gt_state,
            length=args.dataset_len,
            time_power=args.time_power,
        )
        use_pin_memory = args.pin_memory and torch.cuda.is_available()
        # prefetch_factor is only meaningful with worker processes. PyTorch
        # releases before 2.0 reject any non-default value (including None)
        # when num_workers == 0, so pass it only when workers are in use.
        worker_kwargs: dict[str, Any] = {"prefetch_factor": pf} if nw > 0 else {}
        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=nw,
            persistent_workers=persistent,
            pin_memory=use_pin_memory,
            collate_fn=rfm_collate,
            **worker_kwargs,
        )
        _batch_iter = infinite_batches(loader)

        def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return next(_batch_iter)

        print(
            f"batching: DataLoader num_workers={nw} prefetch_factor={pf!s} "
            f"pin_memory={use_pin_memory} "
            f"persistent_workers={persistent}"
        )

    # ------------------------------------------------------------------
    # Training loop. One "epoch" here is one optimizer step over `accum`
    # freshly sampled batches.
    # ------------------------------------------------------------------
    t0 = time.perf_counter()
    log_every = max(1, args.epochs // 10)
    accum = max(1, args.accum)
    best_train_mse = float("inf")
    best_state: dict[str, torch.Tensor] | None = None
    early_stop_best = float("inf")
    early_stop_bad_checks = 0
    stop_reason = "max_epochs"
    stop_epoch = args.epochs - 1
    for epoch in range(args.epochs):
        optimizer.zero_grad(set_to_none=True)
        sum_mse = torch.zeros((), device=device)
        for _ in range(accum):
            batched, time_b, tc, tw, tt = get_train_batch()
            batched = batched.to(device, non_blocking=True)
            time_b = time_b.to(device=device, dtype=torch.float32, non_blocking=True)
            tc = tc.to(device=device, dtype=torch.float32, non_blocking=True)
            tw = tw.to(device=device, dtype=torch.float32, non_blocking=True)
            tt = tt.to(device=device, dtype=torch.float32, non_blocking=True)
            pred = model(batched, time_b)

            # In the displacement domain both predictions and targets are
            # scaled by the remaining time (1 - t), floored at 1e-6.
            time_remain = (1.0 - time_b).clamp(min=1e-6)
            if args.loss_domain == "displacement":
                pred_center = pred.center * time_remain
                pred_omega = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm) * time_remain
                pred_torsion = pred.torsion * time_remain
                tgt_center = tc * time_remain
                tgt_omega = tw * time_remain
                tgt_torsion = tt * time_remain
            else:
                pred_center = pred.center
                pred_omega = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm)
                pred_torsion = pred.torsion
                tgt_center = tc
                tgt_omega = tw
                tgt_torsion = tt

            # Each channel's loss is normalized by the batch mean square of
            # its target (floored at eps) so the channels are comparable.
            eps = 1e-8
            scale_c = (tgt_center * tgt_center).mean().clamp(min=eps)
            scale_w = (tgt_omega * tgt_omega).mean().clamp(min=eps)
            scale_t = (tgt_torsion * tgt_torsion).mean().clamp(min=eps)
            per_center = ((pred_center - tgt_center) ** 2).mean(dim=-1) / scale_c

            # rot_loss feeds the main objective; per_rot is always the plain
            # squared error and is only used by the tail-risk term.
            if args.rotation_loss == "geodesic":
                rot_loss = so3_geodesic_mse_loss(pred_omega, tgt_omega)
                per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
            elif args.rotation_loss == "hybrid":
                alpha = float(max(0.0, min(1.0, args.rotation_hybrid_alpha)))
                rot_loss = (
                    alpha * so3_geodesic_mse_loss(pred_omega, tgt_omega)
                    + (1.0 - alpha) * F.mse_loss(pred_omega, tgt_omega)
                )
                per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
            else:
                rot_loss = F.mse_loss(pred_omega, tgt_omega)
                per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w

            # Torsion error is wrapped through wrap_angle before squaring.
            per_torsion = (wrap_angle(pred_torsion - tgt_torsion) ** 2).mean(dim=-1) / scale_t

            # Optional linear ramp of the rotation weight from
            # --rotation-weight-start to --weight-omega over the warmup epochs.
            current_weight_omega = (
                args.weight_omega
                if args.rotation_weight_start < 0
                else (
                    args.rotation_weight_start
                    + (
                        max(0.0, min(1.0, epoch / max(1, args.rotation_weight_warmup_epochs)))
                        * (args.weight_omega - args.rotation_weight_start)
                    )
                )
            )
            mse = (
                args.weight_center * per_center.mean()
                + current_weight_omega * (rot_loss / scale_w)
                + args.weight_torsion * per_torsion.mean()
            )
            if args.tail_risk_weight > 0:
                per_sample_total = (
                    args.weight_center * per_center
                    + current_weight_omega * per_rot
                    + args.weight_torsion * per_torsion
                )
                mse = mse + args.tail_risk_weight * upper_quantile_mean(
                    per_sample_total, args.tail_risk_quantile
                )
            if not torch.isfinite(mse):
                # Keep graph connectivity while preventing non-finite backward.
                mse = torch.nan_to_num(mse, nan=1e3, posinf=1e3, neginf=1e3)
            (mse / accum).backward()
            sum_mse = sum_mse + mse.detach()

        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        if ema_state is not None:
            ema_update_(ema_state, model, ema_decay)

        avg_mse = sum_mse / accum
        avg_mse_val = float(avg_mse.detach().cpu())

        # Snapshot (on CPU) whenever the train loss improves; with EMA enabled
        # the snapshot is the EMA shadow rather than the online weights.
        if avg_mse_val < best_train_mse:
            best_train_mse = avg_mse_val
            core = model_core(model)
            if ema_state is not None:
                best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
            else:
                best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}

        should_check_early_stop = (
            args.early_stop_patience > 0
            and epoch >= args.early_stop_warmup
            and ((epoch - args.early_stop_warmup) % max(1, args.early_stop_check_every) == 0)
        )
        if should_check_early_stop:
            improved = avg_mse_val < (early_stop_best - args.early_stop_min_delta)
            if improved:
                early_stop_best = avg_mse_val
                early_stop_bad_checks = 0
            else:
                early_stop_bad_checks += 1
            if early_stop_bad_checks >= args.early_stop_patience:
                stop_reason = (
                    "early_stop_train_loss_plateau"
                    f"(patience={args.early_stop_patience},"
                    f"min_delta={args.early_stop_min_delta},"
                    f"check_every={max(1, args.early_stop_check_every)})"
                )
                stop_epoch = epoch
                print(
                    f"early stopping at epoch {epoch}: "
                    f"bad_checks={early_stop_bad_checks}"
                )
                break

        if epoch % log_every == 0 or epoch == args.epochs - 1:
            print(f"epoch {epoch:6d} train_mse {avg_mse_val:.6f} best_train_mse {best_train_mse:.6f}")
        stop_epoch = epoch

    if device.type == "cuda":
        torch.cuda.synchronize()
    print(
        f"training done in {time.perf_counter() - t0:.1f}s "
        f"(stop_reason={stop_reason}, stop_epoch={stop_epoch})"
    )

    # ------------------------------------------------------------------
    # Restore best weights and save checkpoint with model metadata
    # ------------------------------------------------------------------
    ckpt_path = ""
    if best_state is not None:
        core = model_core(model)
        core.load_state_dict(best_state)
        print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
        artifacts_dir = os.path.join(here, "artifacts")
        os.makedirs(artifacts_dir, exist_ok=True)
        ckpt_path = os.path.join(artifacts_dir, "latest_eval_best_model.pt")
        ckpt = {
            "state_dict": best_state,
            "model_type": args.model_type,
            "hidden": int(args.hidden),
            "num_layers": int(args.gcn_layers),
            "num_nodes": int(sample.num_nodes),
            "num_torsions": int(sample.bond_endpoints.shape[0]),
            "sample_sdf": args.sdf,
            "sample_pickle_path": args.pickle_path,
            "sample_pickle_pos": int(args.pickle_pos),
            "eval_runs": int(args.eval_runs),
            "rollout_steps": int(args.rollout_steps),
            "ema_decay": float(ema_decay) if ema_state is not None else 0.0,
            # RFMModel forward parity (required for trajectory / tooling reloads)
            "gcn_residual": bool(args.gcn_residual),
            "channel_layernorm": bool(args.channel_layernorm),
            "head_mlp_layers": int(args.head_mlp_layers),
            "graph_readout": str(args.graph_readout),
            "torsion_head": str(torsion_head_mode),
            "subgraph_torsion_clip": float(args.subgraph_torsion_clip),
            "torsion_sub_gcn_layers": int(args.torsion_sub_gcn_layers),
            "omega_max_norm": float(args.omega_max_norm),
        }
        torch.save(ckpt, ckpt_path)
        print(f"saved best checkpoint: {ckpt_path}")

    # Final test metric: mean final-structure RMSD over --eval-runs rollouts.
    final_mean_rmsd_100 = evaluate_mean_rmsd_100(
        model,
        sample,
        sample.gt_coords,
        device=device,
        num_runs=args.eval_runs,
        rollout_steps=args.rollout_steps,
        omega_max_norm=args.omega_max_norm,
    )
    print(f"final_test_mean_rmsd_{args.eval_runs}: {final_mean_rmsd_100:.6f}")

    reports_dir = os.path.join(here, "reports")
    os.makedirs(reports_dir, exist_ok=True)
    latest_path = os.path.join(reports_dir, "latest_eval.json")
    timestamp = datetime.now(timezone.utc).isoformat()
    cmd = " ".join(sys.argv)
    latest_report = {
        # Key name is kept as-is for existing consumers; the actual number of
        # rollouts is recorded in "num_runs".
        "mean_rmsd_100": float(final_mean_rmsd_100),
        "num_runs": int(args.eval_runs),
        "timestamp_utc": timestamp,
        "command": cmd,
        "notes": f"Final test metric from {int(args.eval_runs)} random-initialized rollouts to time=1.",
        "best_train_mse": float(best_train_mse),
        "model_source": "best_train_checkpoint",
        "checkpoint_path": ckpt_path,
        "stop_reason": stop_reason,
        "stop_epoch": int(stop_epoch),
        "ema_decay": float(ema_decay) if ema_state is not None else 0.0,
    }
    with open(latest_path, "w", encoding="utf-8") as f:
        json.dump(latest_report, f, indent=2)

    # --- Visualizations ---
    sample.coords = sample.gt_coords.clone()
    save_visualization(sample, os.path.join(args.out_dir, "01_gt_reset.sdf"))

    random_trans = generate_random_trans(
        torch.tensor([0, 0, 0], dtype=torch.float32), torch.tensor([3, 3, 3], dtype=torch.float32)
    )
    random_rot = generate_rotation_matrix(generate_random_quaternion())
    random_torsion = generate_random_torsion_angles(sample.bond_endpoints.shape[0])
    random_state = RigidTorsionState(center=random_trans, rot=random_rot, torsion=random_torsion)

    # Apply the random state to a copy so `sample` itself stays untouched.
    vis = deepcopy(sample)
    vis.coords = vis.gt_coords.clone()
    apply_trans(vis, random_trans)
    apply_rot_matrix(vis, random_rot)
    apply_torsion(vis, random_torsion, vis.movable_mask, vis.bond_endpoints)
    save_visualization(vis, os.path.join(args.out_dir, "02_random_init.sdf"))

    velocity_gt = rigid_torsion_velocity(random_state, target_state, torch.tensor(0.0))
    state_list_gt = [get_state(random_state, velocity_gt, float(t)) for t in torch.linspace(0, 1, 50)]
    save_video_visualization(sample, state_list_gt, os.path.join(args.out_dir, "03_trajectory_gt_velocity.sdf"))

    # Learned trajectory: Euler steps using model velocity (batch size 1)
    model.eval()
    state = RigidTorsionState(
        center=random_state.center.to(device),
        rot=random_state.rot.to(device),
        torsion=random_state.torsion.to(device),
    )
    n_frames = 100
    learned_states: list[RigidTorsionState] = []
    for k in range(n_frames):
        # Record the current state first, then advance it by one Euler step.
        learned_states.append(
            RigidTorsionState(
                center=state.center.detach().cpu(),
                rot=state.rot.detach().cpu(),
                torsion=state.torsion.detach().cpu(),
            )
        )
        t_val = k / max(1, n_frames - 1)
        # Rebuild the graph for the current state from the ground-truth coords.
        g = deepcopy(sample)
        g.coords = sample.gt_coords.clone()
        apply_trans(g, state.center.cpu())
        apply_rot_matrix(g, state.rot.cpu())
        apply_torsion(g, state.torsion.cpu(), g.movable_mask, g.bond_endpoints)
        g = g.to(device)
        batch = Batch.from_data_list([g]).to(device)
        tb = torch.tensor([[t_val]], device=device, dtype=torch.float32)
        with torch.no_grad():
            v = model(batch, tb)
        vc = v.center.squeeze(0)
        vw = clip_vector_norm_batch(v.rot_omega, args.omega_max_norm).squeeze(0)
        vt = v.torsion.squeeze(0)
        dt = 1.0 / max(1, n_frames - 1)
        state = RigidTorsionState(
            center=state.center + vc * dt,
            rot=state.rot @ so3_exp(vw * dt),
            torsion=state.torsion + vt * dt,
        )
    save_video_visualization(sample, learned_states, os.path.join(args.out_dir, "04_trajectory_learned_model.sdf"))

    print("Wrote SDFs under:", args.out_dir)
    for name in sorted(os.listdir(args.out_dir)):
        if name.endswith(".sdf"):
            print(" ", name)
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())