Files
ai-rfm/train.py
demian3b 3f04a380d8 Improve geodesic training stability and update best eval baseline.
Make wrapped torsion loss mandatory, add configurable train-loss early stopping controls, and log new architecture/loss attempts; latest geodesic run improves mean_rmsd_100 to 2.429895 for mainline integration.

Made-with: Cursor
2026-04-16 22:20:36 +09:00

1145 lines
40 KiB
Python

#!/usr/bin/env python3
"""
Consolidated RFM toy training + visualization from sandbox_rfm.ipynb.
- DataLoader with num_workers + prefetch builds CPU batches in parallel; tensors moved to GPU for train.
- pin_memory defaults False (fewer shm/pin-thread issues); enable with --pin-memory if stable.
"""
from __future__ import annotations
import argparse
import json
import os
import pickle
import sys
import time
from collections import deque
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from rdkit import Chem
from rdkit.Chem import SDWriter
from rdkit.Geometry.rdGeometry import Point3D
from torch import nn
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv, global_mean_pool
# ---------------------------------------------------------------------------
# Geometry / torsion (from notebook)
# ---------------------------------------------------------------------------
def generate_random_trans(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Sample a random translation from an axis-aligned Gaussian.

    Args:
        mean: Per-component mean. Its shape, device and dtype determine the
            shape, device and dtype of the standard-normal draw.
        std: Per-component standard deviation, multiplied into the draw.

    Returns:
        ``noise * std + mean`` where ``noise`` is standard normal.
    """
    noise = torch.randn(mean.shape, device=mean.device, dtype=mean.dtype)
    scaled = noise * std
    return scaled + mean
def apply_trans(x: Data, t_vec: torch.Tensor) -> Data:
    """Shift all coordinates of ``x`` by ``t_vec``.

    ``x.coords`` is rebound to a freshly computed tensor; the tensor that was
    previously stored there is not modified in place.

    Args:
        x: Object exposing a ``coords`` tensor attribute.
        t_vec: Offset added to ``x.coords`` (broadcast by torch).

    Returns:
        The same ``x`` object, for call chaining.
    """
    shifted = x.coords + t_vec
    x.coords = shifted
    return x
def generate_random_quaternion() -> torch.Tensor:
    """Draw a random unit quaternion from three uniform samples.

    The four components are ``sqrt(1-u1)*sin(2*pi*u2)``,
    ``sqrt(1-u1)*cos(2*pi*u2)``, ``sqrt(u1)*sin(2*pi*u3)`` and
    ``sqrt(u1)*cos(2*pi*u3)``, so the result always has unit norm.
    NOTE(review): this looks like the standard three-uniform construction for
    uniformly distributed rotations -- confirm if uniformity matters.

    Returns:
        Tensor of shape ``[4]`` on the default device/dtype.
    """
    u = torch.rand(3)
    u1, u2, u3 = u[0], u[1], u[2]
    radius_a = (1 - u1).sqrt()
    radius_b = u1.sqrt()
    phase_a = 2 * torch.pi * u2
    phase_b = 2 * torch.pi * u3
    components = [
        radius_a * torch.sin(phase_a),
        radius_a * torch.cos(phase_a),
        radius_b * torch.sin(phase_b),
        radius_b * torch.cos(phase_b),
    ]
    return torch.stack(components, dim=-1)
def generate_rotation_matrix(q: torch.Tensor) -> torch.Tensor:
    """Convert a quaternion into a 3x3 rotation matrix.

    ``q[0]`` is used as the scalar part and ``q[1:4]`` as the vector part.
    The quaternion is not normalised here. An all-zero ``q`` evaluates to the
    identity matrix, which other code in this file relies on to express
    "no rotation".

    The matrix is assembled with ``torch.tensor``, so it is a new tensor that
    is not connected to ``q`` in the autograd graph.

    Args:
        q: Tensor with at least four scalar entries.

    Returns:
        ``[3, 3]`` tensor on ``q``'s device with ``q``'s dtype.
    """
    q0, q1, q2, q3 = q[0], q[1], q[2], q[3]
    row_x = [
        1 - 2 * q2 ** 2 - 2 * q3 ** 2,
        2 * (q1 * q2 - q0 * q3),
        2 * (q1 * q3 + q0 * q2),
    ]
    row_y = [
        2 * (q1 * q2 + q0 * q3),
        1 - 2 * q1 ** 2 - 2 * q3 ** 2,
        2 * (q2 * q3 - q0 * q1),
    ]
    row_z = [
        2 * (q1 * q3 - q0 * q2),
        2 * (q2 * q3 + q0 * q1),
        1 - 2 * q1 ** 2 - 2 * q2 ** 2,
    ]
    return torch.tensor([row_x, row_y, row_z], device=q.device, dtype=q.dtype)
def apply_rot_matrix(x: Data, rot_matrix: torch.Tensor) -> Data:
    """Rotate the coordinates of ``x`` about their own centroid.

    ``x.coords`` is rebound to a new tensor; the centroid (mean over dim 0)
    is left where it was.

    Args:
        x: Object exposing a ``coords`` tensor attribute.
        rot_matrix: Matrix applied to each centred row as ``row @ rot_matrix.T``.

    Returns:
        The same ``x`` object.
    """
    pivot = torch.mean(x.coords, dim=0)
    centred = x.coords - pivot
    rotated = centred @ rot_matrix.T
    x.coords = rotated + pivot
    return x
def _rodrigues(
points: torch.Tensor,
pivot: torch.Tensor,
axis_unit: torch.Tensor,
angle: torch.Tensor,
) -> torch.Tensor:
v = points - pivot
c, s = torch.cos(angle), torch.sin(angle)
cross = torch.cross(axis_unit.expand_as(v), v, dim=-1)
dot_u = (v * axis_unit).sum(-1, keepdim=True)
return pivot + v * c + cross * s + axis_unit * (dot_u * (1 - c))
def apply_torsion(
    x: Any,
    angles: torch.Tensor,
    movable_mask: torch.Tensor,
    bond_endpoints: torch.Tensor,
) -> Any:
    """Rotate the movable side of each bond by its torsion angle, in place.

    Bonds are processed sequentially, so the axis of a later bond is taken
    from coordinates that earlier rotations may already have moved. The tensor
    stored in ``x.coords`` is overwritten in place.

    Args:
        x: Object exposing a ``coords`` tensor.
        angles: One angle (radians) per bond; reshaped to ``[B]``.
        movable_mask: Per-bond boolean mask over atoms selecting what rotates.
        bond_endpoints: Per-bond ``(i, j)`` atom indices; the axis runs from
            atom ``i`` to atom ``j`` and passes through atom ``i``.

    Returns:
        The same ``x`` object.
    """
    coords = x.coords
    num_bonds = angles.shape[0]
    # Bring every input onto the coordinate tensor's device before indexing.
    endpoints = bond_endpoints.to(device=coords.device, dtype=torch.long)
    masks = movable_mask.to(device=coords.device)
    flat_angles = angles.to(device=coords.device, dtype=coords.dtype).reshape(num_bonds)
    for bond, angle in enumerate(flat_angles):
        start = int(endpoints[bond, 0])
        end = int(endpoints[bond, 1])
        direction = coords[end] - coords[start]
        axis = direction / direction.norm()
        moving = masks[bond].nonzero(as_tuple=True)[0]
        # In-place write: later bonds must see the updated positions.
        coords[moving] = _rodrigues(coords[moving], coords[start], axis, angle)
    return x
def generate_random_torsion_angles(
    num_bonds: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Sample ``num_bonds`` torsion angles uniformly between 0 and 2*pi.

    Args:
        num_bonds: Number of angles to draw.
        dtype: Floating dtype of the result.
        device: Target device; ``None`` uses torch's default.

    Returns:
        Tensor of shape ``[num_bonds]``.
    """
    angles = torch.empty(num_bonds, dtype=dtype, device=device)
    # uniform_ fills in place and returns the same tensor.
    return angles.uniform_(0, 2 * torch.pi)
def _movable_j_side(edge_index: torch.Tensor, i: int, j: int, num_atoms: int) -> torch.Tensor:
ei = edge_index[:2]
adj = [[] for _ in range(num_atoms)]
for e in range(ei.shape[1]):
u, v = int(ei[0, e]), int(ei[1, e])
if {u, v} == {i, j}:
continue
adj[u].append(v)
adj[v].append(u)
seen = {j}
q = deque([j])
while q:
u = q.popleft()
for v in adj[u]:
if v not in seen:
seen.add(v)
q.append(v)
m = torch.zeros(num_atoms, dtype=torch.bool)
m[list(seen)] = True
return m
def bfs_torsion_masks(x: Any, root: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """List rotatable bonds in BFS order and build a movable-atom mask for each.

    Rotatable bonds are matched on ``x.rdmol`` with a SMARTS pattern for a
    single, non-ring bond between two atoms that are neither triple-bonded nor
    of degree one. Bonds are ordered by when a breadth-first traversal from
    ``root`` (neighbours visited in ascending index order) first crosses them,
    and are oriented parent -> child. Rotatable bonds that the traversal never
    used as a tree edge are appended afterwards, oriented by BFS depth.

    Args:
        x: Object exposing ``rdmol``, ``coords`` (``[N, ...]``) and
            ``edge_index``.
        root: Atom index at which the traversal starts.

    Returns:
        ``(bond_endpoints, movable_mask)``: a long tensor ``[B, 2]`` of
        ``(i, j)`` pairs and a bool tensor ``[B, N]`` marking, for every bond,
        the atoms on the ``j`` side. When the molecule has no rotatable bond
        the tensors are empty with shapes ``[0, 2]`` and ``[0, N]``.
    """
    rot_pat = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")
    rot_set = {frozenset((int(a), int(b))) for a, b in x.rdmol.GetSubstructMatches(rot_pat)}
    num_atoms = int(x.coords.shape[0])
    ei = x.edge_index

    adj: list[list[int]] = [[] for _ in range(num_atoms)]
    for e in range(ei.shape[1]):
        u, v = int(ei[0, e]), int(ei[1, e])
        adj[u].append(v)
        adj[v].append(u)
    # Sorted neighbour lists make the traversal (and so the bond order) deterministic.
    for neighbours in adj:
        neighbours.sort()

    depth = [-1] * num_atoms
    depth[root] = 0
    queue = deque([root])
    ordered_ij: list[tuple[int, int]] = []
    tree_pairs: set[frozenset[int]] = set()
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if depth[v] != -1:
                continue
            depth[v] = depth[u] + 1
            queue.append(v)
            fs = frozenset((u, v))
            tree_pairs.add(fs)
            if fs in rot_set:
                ordered_ij.append((u, v))

    def orient(fs: frozenset[int]) -> tuple[int, int]:
        # Put the shallower atom first; break depth ties by atom index.
        a, b = tuple(fs)
        da, db = depth[a], depth[b]
        if da < db or (da == db and a < b):
            return (a, b)
        return (b, a)

    for fs in sorted(rot_set - tree_pairs, key=orient):
        ordered_ij.append(orient(fs))

    if not ordered_ij:
        # torch.stack rejects an empty list and torch.tensor([]) would be 1-D,
        # so build correctly shaped empty tensors for rigid molecules.
        return (
            torch.zeros((0, 2), dtype=torch.long),
            torch.zeros((0, num_atoms), dtype=torch.bool),
        )

    bond_endpoints = torch.tensor(ordered_ij, dtype=torch.long)
    movable_mask = torch.stack([_movable_j_side(ei, i, j, num_atoms) for i, j in ordered_ij])
    return bond_endpoints, movable_mask
def vel_center(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Constant velocity that carries ``current`` to ``target`` by time 1.

    The remaining time ``1 - t`` is clamped to at least ``1e-6`` so the
    division stays finite at ``t = 1``.
    """
    remaining = (1 - t).clamp(min=1e-6)
    displacement = target - current
    return displacement / remaining
def wrap_angle(d: torch.Tensor) -> torch.Tensor:
    """Map angles (radians) to their equivalent in ``[-pi, pi]`` via ``atan2``."""
    sine = torch.sin(d)
    cosine = torch.cos(d)
    return torch.atan2(sine, cosine)
def vel_torsion(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Angular velocity that carries ``current`` to ``target`` by time 1.

    The angle difference is wrapped first, so the shorter way around the
    circle is taken. The remaining time is clamped to at least ``1e-6``.
    """
    shortest = wrap_angle(target - current)
    remaining = (1 - t).clamp(min=1e-6)
    return shortest / remaining
def torsion_wrapped_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error between angles, measured on the wrapped difference.

    Predictions that differ from the target by a multiple of 2*pi therefore
    incur no loss.
    """
    residual = wrap_angle(pred - target)
    return (residual ** 2).mean()
def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor) -> torch.Tensor:
    """Mean squared geodesic distance between batched axis-angle increments.

    Each row is mapped to a rotation with ``so3_exp``; the loss for that row
    is the squared norm of ``so3_log(R_pred^T @ R_target)``.

    Args:
        pred_omega: Predicted axis-angle vectors, shape ``[B, 3]``.
        target_omega: Target axis-angle vectors, indexed row by row alongside
            ``pred_omega``.

    Returns:
        Scalar tensor: the mean of the per-row squared distances.

    Raises:
        ValueError: If ``pred_omega`` is not two-dimensional with 3 columns.
    """
    well_formed = pred_omega.ndim == 2 and pred_omega.shape[-1] == 3
    if not well_formed:
        raise ValueError("pred_omega must be [B,3]")
    per_row = []
    for row, omega_pred in enumerate(pred_omega):
        rot_pred = so3_exp(omega_pred)
        rot_target = so3_exp(target_omega[row])
        relative = so3_log(rot_pred.T @ rot_target)
        per_row.append((relative * relative).sum())
    return torch.stack(per_row).mean()
def skew(w: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3-vector.

    The matrix is built by stacking entries of ``w``, so gradients flow back
    to ``w``.
    """
    zero = torch.zeros((), device=w.device, dtype=w.dtype)
    wx, wy, wz = w[0], w[1], w[2]
    entries = [
        zero, -wz, wy,
        wz, zero, -wx,
        -wy, wx, zero,
    ]
    return torch.stack(entries).reshape(3, 3)
def so3_log(R: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Axis-angle vector (log map) of a single 3x3 rotation matrix.

    The angle comes from the trace, ``theta = acos((tr(R) - 1) / 2)``, and the
    axis from the antisymmetric part of ``R``.

    The cosine is clamped away from +/-1 by at least the machine epsilon of
    ``R.dtype``. With the default ``eps=1e-8`` alone, ``1 - eps`` rounds to
    exactly 1 in float32, so the clamp would do nothing: ``acos`` would return
    0 with an infinite derivative (NaN gradients at the identity) and small
    rotations would collapse to the zero vector.

    NOTE(review): near ``theta = pi`` both the antisymmetric part and
    ``sin(theta)`` vanish, so the axis is poorly determined there.

    Args:
        R: Rotation matrix of shape ``[3, 3]`` with a floating dtype.
        eps: Lower bound on the clamp margin; also added to the denominator.

    Returns:
        Tensor of shape ``[3]``: rotation axis scaled by the rotation angle.
    """
    # Never let the margin fall below what the dtype can represent next to 1.
    margin = max(eps, torch.finfo(R.dtype).eps)
    cos_theta = ((R.trace() - 1) * 0.5).clamp(-1 + margin, 1 - margin)
    theta = torch.acos(cos_theta)
    antisym = torch.stack(
        [R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]
    )
    axis = antisym / (2 * torch.sin(theta) + eps)
    return axis * theta
def so3_exp(omega: torch.Tensor) -> torch.Tensor:
    """Rotation matrix (exp map) of a single axis-angle vector.

    Uses ``I + sin(theta) K + (1 - cos(theta)) K^2`` where ``theta`` is the
    norm of ``omega`` and ``K`` is the skew matrix of the normalised axis.
    A ``1e-12`` offset in the normalisation keeps a zero ``omega`` finite,
    in which case the result is the identity.

    Args:
        omega: Tensor of shape ``[3]``.

    Returns:
        Tensor of shape ``[3, 3]`` on ``omega``'s device and dtype.
    """
    angle = omega.norm()
    axis = omega / (angle + 1e-12)
    K = skew(axis)
    identity = torch.eye(3, device=omega.device, dtype=omega.dtype)
    sin_term = torch.sin(angle) * K
    cos_term = (1 - torch.cos(angle)) * (K @ K)
    return identity + sin_term + cos_term
def vel_rot_mat(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Angular velocity that carries rotation ``current`` to ``target`` by time 1.

    The relative rotation ``current^T @ target`` is converted to an axis-angle
    vector and divided by the remaining time (clamped to at least ``1e-6``).
    """
    relative = current.T @ target
    remaining = (1 - t).clamp(min=1e-6)
    return so3_log(relative) / remaining
@dataclass
class RigidTorsionState:
    """Pose of a molecule as a rigid-body placement plus torsion angles.

    Other code in this file turns a state into coordinates by applying, in
    order, ``apply_trans(center)``, ``apply_rot_matrix(rot)`` and
    ``apply_torsion(torsion)`` to the reference coordinates.
    """

    # Translation added to every atom position.
    center: torch.Tensor
    # Rotation matrix applied about the centroid of the coordinates.
    rot: torch.Tensor
    # One angle (radians) per rotatable bond, ordered like bond_endpoints.
    torsion: torch.Tensor
@dataclass
class RigidTorsionVelocity:
    """Rate of change of a ``RigidTorsionState``.

    ``get_state`` integrates it as ``center + v*t``,
    ``rot @ so3_exp(rot_omega*t)`` and ``torsion + v*t``. The model's
    ``forward`` returns the same container with batched tensors.
    """

    # Translational velocity.
    center: torch.Tensor
    # Axis-angle rotation rate (fed to so3_exp after scaling by time).
    rot_omega: torch.Tensor
    # Angular velocity of each torsion.
    torsion: torch.Tensor
def rigid_torsion_velocity(
    current: RigidTorsionState, target: RigidTorsionState, t: torch.Tensor
) -> RigidTorsionVelocity:
    """Velocity that moves ``current`` onto ``target`` in the time left until 1.

    Args:
        current: State at time ``t``.
        target: State to reach at time 1.
        t: Current time; anything ``torch.as_tensor`` accepts.

    Returns:
        Velocities for translation, rotation (axis-angle rate) and torsions.
    """
    now = torch.as_tensor(t)
    return RigidTorsionVelocity(
        vel_center(current.center, target.center, now),
        vel_rot_mat(current.rot, target.rot, now),
        vel_torsion(current.torsion, target.torsion, now),
    )
def get_state(
    current: RigidTorsionState, velocity: RigidTorsionVelocity, t: float
) -> RigidTorsionState:
    """Advance ``current`` along a constant ``velocity`` for a duration ``t``.

    Translation and torsions move linearly; the rotation is right-multiplied
    by ``so3_exp(rot_omega * t)``.

    Returns:
        A new ``RigidTorsionState``; ``current`` is not modified.
    """
    elapsed = torch.as_tensor(t)
    moved_center = current.center + velocity.center * elapsed
    moved_rot = current.rot @ so3_exp(velocity.rot_omega * elapsed)
    moved_torsion = current.torsion + velocity.torsion * elapsed
    return RigidTorsionState(moved_center, moved_rot, moved_torsion)
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def save_visualization(sample: Data, path: str) -> None:
    """Write ``sample.coords`` as the atom positions of ``sample.rdmol`` to ``path``.

    A deep copy of the molecule is edited, so ``sample.rdmol`` keeps its
    original conformer. The file is written with ``Chem.MolToMolFile``.
    """
    mol = deepcopy(sample.rdmol)
    conformer = mol.GetConformer()
    for atom_idx, position in enumerate(sample.coords):
        px, py, pz = position.tolist()
        conformer.SetAtomPosition(atom_idx, Point3D(px, py, pz))
    Chem.MolToMolFile(mol, path)
def save_video_visualization(
    sample: Data,
    states: list[RigidTorsionState],
    path: str,
) -> None:
    """Write one SDF record per state, giving a multi-frame trajectory file.

    Each frame starts from ``sample.gt_coords`` and applies the state's
    translation, rotation and torsions in that order.

    Side effect: ``sample.coords`` is reset to a clone of ``sample.gt_coords``.
    """
    mol = deepcopy(sample.rdmol)
    # Every frame below is derived from a copy of the sample in its reference pose.
    sample.coords = sample.gt_coords.clone()
    with open(path, "w") as handle:
        writer = SDWriter(handle)
        for state in states:
            frame = deepcopy(sample)
            apply_trans(frame, state.center)
            apply_rot_matrix(frame, state.rot)
            apply_torsion(frame, state.torsion, frame.movable_mask, frame.bond_endpoints)
            conformer = mol.GetConformer()
            for atom_idx, position in enumerate(frame.coords):
                px, py, pz = position.tolist()
                conformer.SetAtomPosition(atom_idx, Point3D(px, py, pz))
            writer.write(mol)
        writer.close()
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
def kabsch_rmsd(pred: torch.Tensor, target: torch.Tensor) -> float:
    """RMSD between two point sets after optimal rigid alignment (Kabsch).

    Both sets are centred, then the proper rotation ``W`` minimising
    ``||p @ W - q||`` is found from the SVD of ``H = p^T q``. Points are
    stored as rows, so the minimiser is ``W = U @ V^T``. The previous code
    multiplied by ``V @ U^T``, the inverse of that rotation, which reported a
    non-zero RMSD even for a pure rigid rotation.

    Args:
        pred: Predicted coordinates, shape ``[N, 3]``.
        target: Reference coordinates, shape ``[N, 3]``, same atom order.

    Returns:
        Root-mean-square deviation after alignment, as a Python float.
    """
    p = pred - pred.mean(dim=0, keepdim=True)
    q = target - target.mean(dim=0, keepdim=True)
    h = p.T @ q
    u, _, v_t = torch.linalg.svd(h)
    w = u @ v_t
    if torch.det(w) < 0:
        # Improper solution (a reflection): flip the direction belonging to the
        # smallest singular value to get the best proper rotation instead.
        v_t[-1, :] *= -1
        w = u @ v_t
    p_aligned = p @ w
    sq_dist = torch.sum((p_aligned - q) ** 2, dim=-1)
    return float(torch.sqrt(torch.mean(sq_dist)).item())
def graph_from_state_cpu(sample: Data, center: torch.Tensor, rot: torch.Tensor, torsion: torch.Tensor) -> Data:
    """Build a copy of ``sample`` posed according to a rigid/torsion state.

    Starting from ``sample.gt_coords``, the translation, rotation and torsions
    are applied in that order. ``sample`` itself is left untouched.

    Returns:
        A deep copy of ``sample`` whose ``coords`` hold the posed positions.
    """
    posed = deepcopy(sample)
    posed.coords = sample.gt_coords.clone()
    apply_trans(posed, center)
    apply_rot_matrix(posed, rot)
    apply_torsion(posed, torsion, posed.movable_mask, posed.bond_endpoints)
    return posed
def evaluate_mean_rmsd_100(
    model: nn.Module,
    sample: Data,
    target_coords: torch.Tensor,
    device: torch.device,
    num_runs: int = 100,
    rollout_steps: int = 100,
) -> float:
    """Average final-structure RMSD over ``num_runs`` one-shot predictions.

    For every run a random start pose is drawn, the model is queried once at
    time 0, and its output is added to the start state as a full unit-time
    step. The resulting structure is compared with ``target_coords`` using
    ``kabsch_rmsd``.

    Args:
        model: Velocity model; switched to eval mode here and left that way.
        sample: Prepared sample carrying ``gt_coords``, ``bond_endpoints`` and
            ``movable_mask``.
        target_coords: Reference coordinates for the RMSD.
        device: Device on which the model is run.
        num_runs: Number of random start poses (default 100).
        rollout_steps: Accepted but unused; the prediction is a single step.

    Returns:
        Mean RMSD over all runs.
    """
    del rollout_steps
    model.eval()
    n_torsions = sample.bond_endpoints.shape[0]

    # Draw translation, rotation and torsions per run, in this order, so the
    # random stream is consumed run by run.
    start_centers = []
    start_rots = []
    start_torsions = []
    for _ in range(num_runs):
        start_centers.append(generate_random_trans(torch.zeros(3), torch.tensor([3.0, 3.0, 3.0])))
        start_rots.append(generate_rotation_matrix(generate_random_quaternion()))
        start_torsions.append(generate_random_torsion_angles(n_torsions))
    center_t = torch.stack(start_centers, dim=0)  # [R,3]
    rot_t = torch.stack(start_rots, dim=0)  # [R,3,3]
    torsion_t = torch.stack(start_torsions, dim=0)  # [R,B]

    start_graphs = [
        graph_from_state_cpu(sample, c, r, a)
        for c, r, a in zip(center_t, rot_t, torsion_t)
    ]
    batch = Batch.from_data_list(start_graphs).to(device)
    t_zero = torch.zeros((num_runs, 1), device=device, dtype=torch.float32)
    with torch.no_grad():
        pred = model(batch, t_zero)
    step_center = pred.center.detach().cpu()
    step_omega = pred.rot_omega.detach().cpu()
    step_torsion = pred.torsion.detach().cpu()

    # Single step of length 1 from the start state.
    center_t = center_t + step_center
    torsion_t = torsion_t + step_torsion
    for run in range(num_runs):
        rot_t[run] = rot_t[run] @ so3_exp(step_omega[run])

    rmsds = [
        kabsch_rmsd(
            graph_from_state_cpu(sample, center_t[run], rot_t[run], torsion_t[run]).coords,
            target_coords,
        )
        for run in range(num_runs)
    ]
    return float(sum(rmsds) / len(rmsds))
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class RFMModel(nn.Module):
    """Velocity predictor with a GCN or fixed-atom-order MLP trunk.

    The trunk encodes atom coordinates into one feature vector per graph.
    Three time features (``t``, ``sin(2*pi*t)``, ``cos(2*pi*t)``) are appended
    and the result feeds three independent heads for translation, rotation
    (axis-angle) and torsion velocities.
    """

    def __init__(
        self,
        model_type: str,
        num_nodes: int,
        num_torsions: int,
        hidden: int = 128,
        num_layers: int = 4,
        gcn_residual: bool = False,
        channel_layernorm: bool = False,
        head_mlp_layers: int = 1,
    ):
        """Build trunk and heads.

        Args:
            model_type: ``"gcn"`` or ``"mlp"``.
            num_nodes: Atoms per graph (the MLP input width is ``num_nodes * 3``).
            num_torsions: Output width of the torsion head.
            hidden: Trunk feature width.
            num_layers: Number of trunk layers.
            gcn_residual: Add a skip connection around GCN layers whose input
                and output shapes match.
            channel_layernorm: LayerNorm on the pooled features and inside
                the hidden layers of each head.
            head_mlp_layers: Linear layers per head; values below 1 become 1.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        super().__init__()
        self.model_type = model_type
        self.num_nodes = num_nodes
        self.num_torsions = num_torsions
        self.hidden = hidden
        self.num_layers = num_layers
        self.gcn_residual = gcn_residual
        self.channel_layernorm = channel_layernorm
        self.act = nn.SiLU()
        self.head_mlp_layers = max(1, int(head_mlp_layers))

        if model_type not in ("gcn", "mlp"):
            raise ValueError(f"Unsupported model_type: {model_type}")

        if model_type == "gcn":
            # The first conv maps xyz -> hidden and is always present.
            conv_inputs = [3] + [hidden] * (num_layers - 1)
            self.convs = nn.ModuleList([GCNConv(width, hidden) for width in conv_inputs])
            self.gcn_skip = nn.ModuleList([nn.Identity() for _ in conv_inputs])
        else:
            stages: list[nn.Module] = []
            fan_in = num_nodes * 3
            for _ in range(num_layers):
                stages.extend([nn.Linear(fan_in, hidden), nn.SiLU()])
                fan_in = hidden
            self.encoder = nn.Sequential(*stages)
        trunk_dim = hidden

        self.time_feat_dim = 3
        head_width = trunk_dim + self.time_feat_dim
        self.shared_norm = nn.LayerNorm(head_width) if channel_layernorm else nn.Identity()
        self.trans_head = self._build_head(head_width, 3)
        self.rot_head = self._build_head(head_width, 3)
        self.torsion_head = self._build_head(head_width, num_torsions)

    def _build_head(self, width: int, out_dim: int) -> nn.Sequential:
        """Stack ``head_mlp_layers - 1`` hidden blocks and a final projection."""
        blocks: list[nn.Module] = []
        for _ in range(self.head_mlp_layers - 1):
            blocks.append(nn.Linear(width, width))
            if self.channel_layernorm:
                blocks.append(nn.LayerNorm(width))
            blocks.append(nn.SiLU())
        blocks.append(nn.Linear(width, out_dim))
        return nn.Sequential(*blocks)

    def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
        """Predict velocities for a batch of posed graphs.

        Args:
            x: Batched graph exposing ``coords`` (plus ``edge_index`` and
                ``batch`` for the GCN trunk).
            time: Per-graph time, concatenated with the trunk features along
                the last axis; callers in this file pass shape ``[B, 1]``.

        Returns:
            ``RigidTorsionVelocity`` with one row per graph.
        """
        if self.model_type == "gcn":
            feats = x.coords
            edge_index = x.edge_index
            for skip, conv in zip(self.gcn_skip, self.convs):
                updated = self.act(conv(feats, edge_index))
                if self.gcn_residual and updated.shape == feats.shape:
                    updated = updated + skip(feats)
                feats = updated
            trunk = global_mean_pool(feats, x.batch)
        else:
            # Fixed atom order: flatten each graph's coordinates into one row.
            flat = x.coords.reshape(time.shape[0], self.num_nodes * 3)
            trunk = self.encoder(flat)

        t = time.to(dtype=trunk.dtype, device=trunk.device)
        phase = 2.0 * torch.pi * t
        time_feats = torch.cat([t, torch.sin(phase), torch.cos(phase)], dim=-1)
        pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
        return RigidTorsionVelocity(
            center=self.trans_head(pooled),
            rot_omega=self.rot_head(pooled),
            torsion=self.torsion_head(pooled),
        )
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_sample_from_pickle(path: str, pos: int) -> Data:
    """Load one pickled sample stored at byte offset ``pos`` of ``path``.

    Only the first two rows of the stored ``edge_index`` are kept. The
    attributes ``gt_coords``, ``gt_trans``, ``gt_quaternion``, ``gt_torsion``,
    ``bond_endpoints`` and ``movable_mask`` are copied over when present.

    SECURITY: ``pickle.load`` can execute arbitrary code; only use this on
    shards from a trusted source.
    """
    with open(path, "rb") as handle:
        handle.seek(pos)
        raw = pickle.load(handle)
    raw.edge_index = raw.edge_index[:2, :]
    sample = Data(coords=raw.coords, edge_index=raw.edge_index)
    sample.num_nodes = int(raw.coords.shape[0])
    sample.rdmol = raw.rdmol
    optional_fields = (
        "gt_coords",
        "gt_trans",
        "gt_quaternion",
        "gt_torsion",
        "bond_endpoints",
        "movable_mask",
    )
    for field_name in optional_fields:
        if not hasattr(raw, field_name):
            continue
        setattr(sample, field_name, getattr(raw, field_name))
    return sample
def load_sample_from_sdf(sdf_path: str) -> Data:
    """Read a molecule file with RDKit and convert it to a graph sample.

    Hydrogens are kept (``removeHs=False``). Every bond contributes two
    directed edges, one in each direction.

    Args:
        sdf_path: Path handed to ``Chem.MolFromMolFile``.

    Returns:
        ``Data`` with ``coords`` (float32 ``[N, 3]``), ``edge_index`` (long
        ``[2, E]``, also when the molecule has no bonds), ``num_nodes`` and a
        copy of the RDKit molecule in ``rdmol``.

    Raises:
        RuntimeError: If RDKit cannot parse the file.
    """
    mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
    if mol is None:
        raise RuntimeError(f"RDKit could not read {sdf_path}")
    conf = mol.GetConformer()
    # Fetch each position once instead of once per component.
    positions = [conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())]
    coords = torch.tensor(
        [[pos.x, pos.y, pos.z] for pos in positions],
        dtype=torch.float32,
    ).reshape(-1, 3)
    edges: list[tuple[int, int]] = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append((i, j))
        edges.append((j, i))
    # reshape(-1, 2) keeps the result two-dimensional when there are no bonds.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).T.contiguous()
    data = Data()
    data.coords = coords
    data.edge_index = edge_index
    data.num_nodes = int(coords.shape[0])
    data.rdmol = Chem.Mol(mol)
    return data
def prepare_sample(sample: Data) -> Data:
    """Attach reference ("ground-truth") fields and torsion metadata to ``sample``.

    The current coordinates become the reference pose: ``gt_coords`` is a
    clone of ``coords``, ``gt_trans`` their mean, ``gt_quaternion`` all zeros
    (which ``generate_rotation_matrix`` maps to the identity) and
    ``gt_torsion`` all zeros. Rotatable bonds and their movable-atom masks
    come from ``bfs_torsion_masks``.

    Returns:
        The same ``sample`` object, modified in place.
    """
    coords = sample.coords
    sample.num_nodes = int(coords.shape[0])
    sample.gt_coords = coords.clone()
    sample.gt_trans = torch.mean(coords, dim=0, dtype=torch.float32)
    sample.gt_quaternion = torch.tensor((0, 0, 0, 0), dtype=torch.float32)
    endpoints, masks = bfs_torsion_masks(sample)
    sample.bond_endpoints = endpoints
    sample.movable_mask = masks
    sample.gt_torsion = torch.zeros(endpoints.shape[0], dtype=torch.float32)
    return sample
# ---------------------------------------------------------------------------
# CPU batch building (DataLoader workers or main process)
# ---------------------------------------------------------------------------
def _gt_state_cpu(gs: RigidTorsionState) -> RigidTorsionState:
    """Return a copy of ``gs`` whose three tensors live on the CPU."""
    on_cpu = {
        "center": gs.center.cpu(),
        "rot": gs.rot.cpu(),
        "torsion": gs.torsion.cpu(),
    }
    return RigidTorsionState(**on_cpu)
def make_rfm_example(
    sample: Data,
    gt_state_cpu: RigidTorsionState,
    cpu: torch.device,
    time_power: float = 1.0,
) -> dict[str, Any]:
    """Create one training example: a random pose at a random time plus targets.

    A random start state is drawn, the constant velocity that carries it to
    ``gt_state_cpu`` in unit time is computed, and the start state is advanced
    along that velocity to ``t = U(0,1) ** time_power``. Everything stays on
    the CPU, so this is safe to call from DataLoader workers.

    Returns:
        Dict with ``"data"`` (posed copy of ``sample``), the velocity targets
        ``"v_center"``, ``"v_omega"``, ``"v_torsion"``, and ``"time"``.
    """
    # Random draws happen in a fixed order: translation, rotation, torsions, time.
    zero3 = torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu)
    spread3 = torch.tensor([3, 3, 3], dtype=torch.float32, device=cpu)
    start = RigidTorsionState(
        center=generate_random_trans(zero3, spread3),
        rot=generate_rotation_matrix(generate_random_quaternion()),
        torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0], device=cpu),
    )
    velocity = rigid_torsion_velocity(start, gt_state_cpu, torch.tensor(0.0, device=cpu))
    t_scalar = torch.rand((), device=cpu) ** time_power
    pose = get_state(start, velocity, float(t_scalar.item()))

    graph = deepcopy(sample)
    graph.coords = graph.coords.to(device=cpu, dtype=torch.float32)
    graph.edge_index = graph.edge_index.to(device=cpu)
    for attr in ("movable_mask", "bond_endpoints"):
        if hasattr(graph, attr):
            setattr(graph, attr, getattr(graph, attr).to(device=cpu))
    apply_trans(graph, pose.center)
    apply_rot_matrix(graph, pose.rot)
    apply_torsion(graph, pose.torsion, graph.movable_mask, graph.bond_endpoints)

    example: dict[str, Any] = {"data": graph}
    example["v_center"] = velocity.center.detach().cpu()
    example["v_omega"] = velocity.rot_omega.detach().cpu()
    example["v_torsion"] = velocity.torsion.detach().cpu()
    example["time"] = t_scalar.detach().cpu()
    return example
def rfm_collate(batch: list[dict[str, Any]]) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate ``make_rfm_example`` dicts into batched tensors.

    Returns:
        ``(graphs, time, v_center, v_omega, v_torsion)``. ``time`` gets a
        trailing singleton axis; the other tensors are stacked along a new
        leading batch axis.
    """

    def stacked(key: str) -> torch.Tensor:
        return torch.stack([row[key] for row in batch], dim=0)

    batched = Batch.from_data_list([row["data"] for row in batch])
    time_b = stacked("time").unsqueeze(-1)
    return batched, time_b, stacked("v_center"), stacked("v_omega"), stacked("v_torsion")
class SameGraphRandomRFMDataset(Dataset):
    """Synthetic dataset over one fixed molecule.

    Every item is a freshly randomised RFM example built on the CPU from the
    same ``sample``. The index is ignored; ``length`` only sets the reported
    size.
    """

    def __init__(
        self,
        sample: Data,
        gt_state_cpu: RigidTorsionState,
        length: int,
        time_power: float = 1.0,
    ):
        """Store the sample, a CPU copy of the target state, and sampling options."""
        self.sample = sample
        self.gt_state = _gt_state_cpu(gt_state_cpu)
        self.length = length
        self.time_power = time_power

    def __len__(self) -> int:
        """Return the nominal dataset size."""
        return self.length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return a new random example; ``idx`` is deliberately unused."""
        return make_rfm_example(
            self.sample,
            self.gt_state,
            torch.device("cpu"),
            self.time_power,
        )
def build_one_batch(
    sample: Data,
    gt_state: RigidTorsionState,
    batch_size: int,
    cpu: torch.device,
    time_power: float = 1.0,
) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build one batch in the calling process, without DataLoader workers.

    Produces the same tuple layout as ``DataLoader`` with ``rfm_collate``.
    """
    examples = []
    for _ in range(batch_size):
        examples.append(make_rfm_example(sample, gt_state, cpu, time_power=time_power))
    return rfm_collate(examples)
def infinite_batches(loader: DataLoader):
    """Yield batches from ``loader`` forever, restarting it whenever it runs out."""
    while True:
        yield from loader
def main() -> int:
    """Train a velocity model on one ligand, evaluate it, and write visualizations.

    Steps: parse CLI options, load a sample (pickle shard or SDF), train with
    random poses generated on the fly, restore the weights with the lowest
    observed train loss, save a checkpoint and a JSON report next to this
    file, and write SDF snapshots/trajectories to ``--out-dir``.

    Returns:
        Process exit code: 0 on success, 1 when no input sample was found.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser()
    parser.add_argument("--pickle-path", default="", help="Optional pickle shard path")
    parser.add_argument("--pickle-pos", type=int, default=61908729)
    parser.add_argument("--sdf", default=os.path.join(here, "sample.sdf"), help="Fallback ligand SDF")
    parser.add_argument("--out-dir", default=os.path.join(here, "rfm_out"))
    parser.add_argument("--epochs", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
    parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
    parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
    parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
    parser.add_argument(
        "--model-type",
        type=str,
        default="mlp",
        choices=["mlp", "gcn"],
        help="Backbone encoder for velocity prediction.",
    )
    # NOTE: this value is passed as num_layers to the model, so it also sets
    # the depth of the MLP encoder.
    parser.add_argument(
        "--gcn-layers",
        type=int,
        default=4,
        help="Number of GCNConv layers (more => longer GPU kernels)",
    )
    parser.add_argument(
        "--accum",
        type=int,
        default=1,
        help="Gradient accumulation steps per optimizer step (keeps GPU busier vs CPU batch build)",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="torch.compile(model) when supported (PyTorch 2+)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=-1,
        help="DataLoader workers; -1 = min(8, CPU-1), 0 = main process only",
    )
    parser.add_argument(
        "--prefetch-factor",
        type=int,
        default=4,
        help="Per-worker prefetch queue depth (ignored if num-workers=0)",
    )
    parser.add_argument(
        "--pin-memory",
        action="store_true",
        help="DataLoader pin_memory (faster H2D; can hit shm/pin issues in some setups)",
    )
    parser.add_argument(
        "--sync-batch",
        action="store_true",
        help="Build batches in main thread (no DataLoader workers)",
    )
    parser.add_argument(
        "--dataset-len",
        type=int,
        default=1_000_000,
        help="Synthetic Dataset length (only affects shuffle / epoch size)",
    )
    parser.add_argument(
        "--eval-runs",
        type=int,
        default=100,
        help="Number of rollout tests for final RMSD report",
    )
    # NOTE: forwarded to evaluate_mean_rmsd_100, which discards it.
    parser.add_argument(
        "--rollout-steps",
        type=int,
        default=100,
        help="Integration steps per rollout during final test",
    )
    parser.add_argument(
        "--time-power",
        type=float,
        default=1.0,
        help="Sample t as U(0,1)^time_power; >1 biases training toward t~0 states.",
    )
    parser.add_argument(
        "--loss-domain",
        type=str,
        default="displacement",
        choices=["velocity", "displacement"],
        help="Train on raw velocity or displacement-to-target (v * (1 - t)).",
    )
    parser.add_argument(
        "--rotation-loss",
        type=str,
        default="mse",
        choices=["mse", "geodesic"],
        help="Rotation-channel loss type.",
    )
    parser.add_argument(
        "--gcn-residual",
        action="store_true",
        help="Enable residual skip in GCN trunk.",
    )
    parser.add_argument(
        "--channel-layernorm",
        action="store_true",
        help="Apply layernorm to pooled trunk and head hidden layers.",
    )
    parser.add_argument(
        "--head-mlp-layers",
        type=int,
        default=1,
        help="Per-channel head depth (>=1).",
    )
    parser.add_argument(
        "--early-stop-patience",
        type=int,
        default=0,
        help="Train-loss early stop patience in checks (0 disables).",
    )
    parser.add_argument(
        "--early-stop-min-delta",
        type=float,
        default=0.0,
        help="Minimum train-loss improvement required to reset patience.",
    )
    parser.add_argument(
        "--early-stop-check-every",
        type=int,
        default=1,
        help="Evaluate early-stop condition every N epochs.",
    )
    parser.add_argument(
        "--early-stop-warmup",
        type=int,
        default=0,
        help="Skip early-stop checks for first N epochs.",
    )
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")
    # Input selection: an existing pickle shard wins over the SDF fallback.
    if args.pickle_path and os.path.isfile(args.pickle_path):
        print("Loading sample from pickle:", args.pickle_path, "pos", args.pickle_pos)
        sample = load_sample_from_pickle(args.pickle_path, args.pickle_pos)
    elif os.path.isfile(args.sdf):
        print("Loading sample from SDF:", args.sdf)
        sample = load_sample_from_sdf(args.sdf)
    else:
        print("No pickle or SDF found. Provide --pickle-path or --sdf", file=sys.stderr)
        return 1
    sample = prepare_sample(sample)
    os.makedirs(args.out_dir, exist_ok=True)
    save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))
    # Training target: zero translation, zero torsions, and the rotation that
    # generate_rotation_matrix returns for an all-zero quaternion (identity).
    gt_state = RigidTorsionState(
        center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
        rot=generate_rotation_matrix(torch.tensor([0, 0, 0, 0], dtype=torch.float32)),
        torsion=torch.zeros(sample.bond_endpoints.shape[0], dtype=torch.float32),
    )
    # Same target built from the fields prepare_sample stored on the sample;
    # used only for the ground-truth trajectory visualization below.
    target_state = RigidTorsionState(
        center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
        rot=generate_rotation_matrix(sample.gt_quaternion),
        torsion=sample.gt_torsion,
    )
    model = RFMModel(
        model_type=args.model_type,
        num_nodes=int(sample.num_nodes),
        num_torsions=int(sample.bond_endpoints.shape[0]),
        hidden=args.hidden,
        num_layers=args.gcn_layers,
        gcn_residual=args.gcn_residual,
        channel_layernorm=args.channel_layernorm,
        head_mlp_layers=args.head_mlp_layers,
    ).to(device)
    if args.compile and hasattr(torch, "compile"):
        model = torch.compile(model)  # type: ignore[assignment]
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(
        f"device={device} model={args.model_type} hidden={args.hidden} "
        f"layers={args.gcn_layers} params={n_params:,}"
    )
    print(f"time sampling power={args.time_power}")
    print(
        "loss weights: "
        f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
        f"grad_clip={args.grad_clip}"
    )
    print("loss domain=%s torsion_wrapped_loss=always_on" % args.loss_domain)
    print(
        f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
        f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers}"
    )
    print(
        "early_stop: "
        f"patience={args.early_stop_patience} min_delta={args.early_stop_min_delta} "
        f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
    )
    # Resolve the worker count; prefetch_factor and persistent workers are
    # only enabled when at least one worker is used.
    nw = args.num_workers
    if nw < 0:
        nw = min(8, max(0, (os.cpu_count() or 4) - 1))
    pf = args.prefetch_factor if nw > 0 else None
    persistent = nw > 0
    # Both branches define get_train_batch() with the same return layout.
    if args.sync_batch:
        def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return build_one_batch(
                sample,
                gt_state,
                args.batch_size,
                cpu,
                args.time_power,
            )
        print(f"batching: sync (main thread) workers=0")
    else:
        dataset = SameGraphRandomRFMDataset(
            sample,
            gt_state_cpu=gt_state,
            length=args.dataset_len,
            time_power=args.time_power,
        )
        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=nw,
            prefetch_factor=pf,
            persistent_workers=persistent,
            pin_memory=args.pin_memory and torch.cuda.is_available(),
            collate_fn=rfm_collate,
        )
        _batch_iter = infinite_batches(loader)
        def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return next(_batch_iter)
        print(
            f"batching: DataLoader num_workers={nw} prefetch_factor={pf!s} "
            f"pin_memory={args.pin_memory and torch.cuda.is_available()} "
            f"persistent_workers={persistent}"
        )
    t0 = time.perf_counter()
    log_every = max(1, args.epochs // 10)
    accum = max(1, args.accum)
    best_train_mse = float("inf")
    best_state: dict[str, torch.Tensor] | None = None
    early_stop_best = float("inf")
    early_stop_bad_checks = 0
    stop_reason = "max_epochs"
    stop_epoch = args.epochs - 1
    # One "epoch" here is one optimizer step over `accum` freshly sampled batches.
    for epoch in range(args.epochs):
        optimizer.zero_grad(set_to_none=True)
        sum_mse = torch.zeros((), device=device)
        for _ in range(accum):
            batched, time_b, tc, tw, tt = get_train_batch()
            batched = batched.to(device, non_blocking=True)
            time_b = time_b.to(device=device, dtype=torch.float32, non_blocking=True)
            tc = tc.to(device=device, dtype=torch.float32, non_blocking=True)
            tw = tw.to(device=device, dtype=torch.float32, non_blocking=True)
            tt = tt.to(device=device, dtype=torch.float32, non_blocking=True)
            pred = model(batched, time_b)
            time_remain = (1.0 - time_b).clamp(min=1e-6)
            if args.loss_domain == "displacement":
                # Compare remaining displacement (v * (1 - t)) instead of raw velocity.
                pred_center = pred.center * time_remain
                pred_omega = pred.rot_omega * time_remain
                pred_torsion = pred.torsion * time_remain
                tgt_center = tc * time_remain
                tgt_omega = tw * time_remain
                tgt_torsion = tt * time_remain
            else:
                pred_center = pred.center
                pred_omega = pred.rot_omega
                pred_torsion = pred.torsion
                tgt_center = tc
                tgt_omega = tw
                tgt_torsion = tt
            # Each channel's loss is divided by the mean squared target of the
            # current batch, floored at eps.
            eps = 1e-8
            scale_c = (tgt_center * tgt_center).mean().clamp(min=eps)
            scale_w = (tgt_omega * tgt_omega).mean().clamp(min=eps)
            scale_t = (tgt_torsion * tgt_torsion).mean().clamp(min=eps)
            if args.rotation_loss == "geodesic":
                rot_loss = so3_geodesic_mse_loss(pred_omega, tgt_omega)
            else:
                rot_loss = F.mse_loss(pred_omega, tgt_omega)
            mse = (
                args.weight_center * (F.mse_loss(pred_center, tgt_center) / scale_c)
                + args.weight_omega * (rot_loss / scale_w)
                + args.weight_torsion
                * (
                    torsion_wrapped_mse_loss(pred_torsion, tgt_torsion)
                    / scale_t
                )
            )
            # Divide by accum so accumulated gradients average over micro-batches.
            (mse / accum).backward()
            sum_mse = sum_mse + mse.detach()
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        avg_mse = sum_mse / accum
        avg_mse_val = float(avg_mse.detach().cpu())
        # The loss was measured before optimizer.step(), while the snapshot is
        # taken after it.
        if avg_mse_val < best_train_mse:
            best_train_mse = avg_mse_val
            # Unwrap a torch.compile wrapper so state-dict keys match the plain module.
            core = model._orig_mod if hasattr(model, "_orig_mod") else model
            best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
        should_check_early_stop = (
            args.early_stop_patience > 0
            and epoch >= args.early_stop_warmup
            and ((epoch - args.early_stop_warmup) % max(1, args.early_stop_check_every) == 0)
        )
        if should_check_early_stop:
            improved = avg_mse_val < (early_stop_best - args.early_stop_min_delta)
            if improved:
                early_stop_best = avg_mse_val
                early_stop_bad_checks = 0
            else:
                early_stop_bad_checks += 1
            if early_stop_bad_checks >= args.early_stop_patience:
                stop_reason = (
                    "early_stop_train_loss_plateau"
                    f"(patience={args.early_stop_patience},"
                    f"min_delta={args.early_stop_min_delta},"
                    f"check_every={max(1, args.early_stop_check_every)})"
                )
                stop_epoch = epoch
                print(
                    f"early stopping at epoch {epoch}: "
                    f"bad_checks={early_stop_bad_checks}"
                )
                break
        if epoch % log_every == 0 or epoch == args.epochs - 1:
            print(f"epoch {epoch:6d} train_mse {avg_mse_val:.6f} best_train_mse {best_train_mse:.6f}")
        stop_epoch = epoch
    if device.type == "cuda":
        # Wait for queued GPU work so the reported wall time is meaningful.
        torch.cuda.synchronize()
    print(
        f"training done in {time.perf_counter() - t0:.1f}s "
        f"(stop_reason={stop_reason}, stop_epoch={stop_epoch})"
    )
    # ckpt_path stays empty when no epoch ran, i.e. no best state exists.
    ckpt_path = ""
    if best_state is not None:
        core = model._orig_mod if hasattr(model, "_orig_mod") else model
        core.load_state_dict(best_state)
        print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
        artifacts_dir = os.path.join(here, "artifacts")
        os.makedirs(artifacts_dir, exist_ok=True)
        ckpt_path = os.path.join(artifacts_dir, "latest_eval_best_model.pt")
        ckpt = {
            "state_dict": best_state,
            "model_type": args.model_type,
            "hidden": int(args.hidden),
            "num_layers": int(args.gcn_layers),
            "num_nodes": int(sample.num_nodes),
            "num_torsions": int(sample.bond_endpoints.shape[0]),
            "sample_sdf": args.sdf,
            "sample_pickle_path": args.pickle_path,
            "sample_pickle_pos": int(args.pickle_pos),
            "eval_runs": int(args.eval_runs),
            "rollout_steps": int(args.rollout_steps),
        }
        torch.save(ckpt, ckpt_path)
        print(f"saved best checkpoint: {ckpt_path}")
    # Final test metric: mean final-structure RMSD over --eval-runs one-shot predictions
    final_mean_rmsd_100 = evaluate_mean_rmsd_100(
        model,
        sample,
        sample.gt_coords,
        device=device,
        num_runs=args.eval_runs,
        rollout_steps=args.rollout_steps,
    )
    print(f"final_test_mean_rmsd_{args.eval_runs}: {final_mean_rmsd_100:.6f}")
    reports_dir = os.path.join(here, "reports")
    os.makedirs(reports_dir, exist_ok=True)
    latest_path = os.path.join(reports_dir, "latest_eval.json")
    timestamp = datetime.now(timezone.utc).isoformat()
    cmd = " ".join(sys.argv)
    # The report key and note text say "100" regardless of --eval-runs;
    # "num_runs" records the actual count.
    latest_report = {
        "mean_rmsd_100": float(final_mean_rmsd_100),
        "num_runs": int(args.eval_runs),
        "timestamp_utc": timestamp,
        "command": cmd,
        "notes": "Final test metric from 100 random-initialized rollouts to time=1.",
        "best_train_mse": float(best_train_mse),
        "model_source": "best_train_checkpoint",
        "checkpoint_path": ckpt_path,
        "stop_reason": stop_reason,
        "stop_epoch": int(stop_epoch),
    }
    with open(latest_path, "w", encoding="utf-8") as f:
        json.dump(latest_report, f, indent=2)
    # --- Visualizations ---
    sample.coords = sample.gt_coords.clone()
    save_visualization(sample, os.path.join(args.out_dir, "01_gt_reset.sdf"))
    random_trans = generate_random_trans(
        torch.tensor([0, 0, 0], dtype=torch.float32), torch.tensor([3, 3, 3], dtype=torch.float32)
    )
    random_rot = generate_rotation_matrix(generate_random_quaternion())
    random_torsion = generate_random_torsion_angles(sample.bond_endpoints.shape[0])
    random_state = RigidTorsionState(center=random_trans, rot=random_rot, torsion=random_torsion)
    vis = deepcopy(sample)
    vis.coords = vis.gt_coords.clone()
    apply_trans(vis, random_trans)
    apply_rot_matrix(vis, random_rot)
    apply_torsion(vis, random_torsion, vis.movable_mask, vis.bond_endpoints)
    save_visualization(vis, os.path.join(args.out_dir, "02_random_init.sdf"))
    # Reference trajectory: follow the analytic constant velocity from the
    # random start to the target over 50 frames.
    velocity_gt = rigid_torsion_velocity(random_state, target_state, torch.tensor(0.0))
    state_list_gt = [get_state(random_state, velocity_gt, float(t)) for t in torch.linspace(0, 1, 50)]
    save_video_visualization(sample, state_list_gt, os.path.join(args.out_dir, "03_trajectory_gt_velocity.sdf"))
    # Learned trajectory: Euler steps using model velocity (batch size 1)
    model.eval()
    state = RigidTorsionState(
        center=random_state.center.to(device),
        rot=random_state.rot.to(device),
        torsion=random_state.torsion.to(device),
    )
    n_frames = 100
    learned_states: list[RigidTorsionState] = []
    for k in range(n_frames):
        # Record the state before stepping, so frame 0 is the random start.
        learned_states.append(
            RigidTorsionState(
                center=state.center.detach().cpu(),
                rot=state.rot.detach().cpu(),
                torsion=state.torsion.detach().cpu(),
            )
        )
        t_val = k / max(1, n_frames - 1)
        # Rebuild the posed graph on CPU from the current state, then move it
        # to the model's device.
        g = deepcopy(sample)
        g.coords = sample.gt_coords.clone()
        apply_trans(g, state.center.cpu())
        apply_rot_matrix(g, state.rot.cpu())
        apply_torsion(g, state.torsion.cpu(), g.movable_mask, g.bond_endpoints)
        g = g.to(device)
        batch = Batch.from_data_list([g]).to(device)
        tb = torch.tensor([[t_val]], device=device, dtype=torch.float32)
        with torch.no_grad():
            v = model(batch, tb)
        vc = v.center.squeeze(0)
        vw = v.rot_omega.squeeze(0)
        vt = v.torsion.squeeze(0)
        dt = 1.0 / max(1, n_frames - 1)
        state = RigidTorsionState(
            center=state.center + vc * dt,
            rot=state.rot @ so3_exp(vw * dt),
            torsion=state.torsion + vt * dt,
        )
    save_video_visualization(sample, learned_states, os.path.join(args.out_dir, "04_trajectory_learned_model.sdf"))
    print("Wrote SDFs under:", args.out_dir)
    for name in sorted(os.listdir(args.out_dir)):
        if name.endswith(".sdf"):
            print(" ", name)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())