Make wrapped torsion loss mandatory, add configurable train-loss early stopping controls, and log new architecture/loss attempts; latest geodesic run improves mean_rmsd_100 to 2.429895 for mainline integration. Made-with: Cursor
1145 lines
40 KiB
Python
1145 lines
40 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Consolidated RFM toy training + visualization from sandbox_rfm.ipynb.
|
|
- DataLoader with num_workers + prefetch builds CPU batches in parallel; tensors moved to GPU for train.
|
|
- pin_memory defaults False (fewer shm/pin-thread issues); enable with --pin-memory if stable.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import pickle
|
|
import sys
|
|
import time
|
|
from collections import deque
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from rdkit import Chem
|
|
from rdkit.Chem import SDWriter
|
|
from rdkit.Geometry.rdGeometry import Point3D
|
|
from torch import nn
|
|
from torch_geometric.data import Batch, Data
|
|
from torch_geometric.nn import GCNConv, global_mean_pool
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Geometry / torsion (from notebook)
|
|
# ---------------------------------------------------------------------------
|
|
def generate_random_trans(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Sample a Gaussian translation ``mean + std * N(0, I)`` with ``mean``'s shape.

    The standard-normal noise is drawn from the global torch RNG on ``mean``'s
    device and dtype; ``std`` broadcasts against it.
    """
    noise = torch.randn(mean.shape, device=mean.device, dtype=mean.dtype)
    return mean + std * noise
|
|
|
|
|
|
def apply_trans(x: Data, t_vec: torch.Tensor) -> Data:
    """Translate every point of ``x.coords`` by ``t_vec`` and return ``x``.

    The addition is out of place: ``x.coords`` is rebound to a new tensor, so
    other references to the previous coordinate tensor are not modified.
    """
    shifted = x.coords + t_vec
    x.coords = shifted
    return x
|
|
|
|
|
|
def generate_random_quaternion() -> torch.Tensor:
    """Draw a uniformly distributed unit quaternion (Shoemake's method).

    Consumes three ``U(0, 1)`` samples from the global torch RNG and returns a
    length-4 tensor. ``generate_rotation_matrix`` in this file reads component
    0 as the scalar part.
    """
    u = torch.rand(3)
    radius_a = (1 - u[0]).sqrt()
    radius_b = u[0].sqrt()
    phase_a = 2 * torch.pi * u[1]
    phase_b = 2 * torch.pi * u[2]
    components = [
        radius_a * torch.sin(phase_a),
        radius_a * torch.cos(phase_a),
        radius_b * torch.sin(phase_b),
        radius_b * torch.cos(phase_b),
    ]
    return torch.stack(components, dim=-1)
|
|
|
|
|
|
def generate_rotation_matrix(q: torch.Tensor) -> torch.Tensor:
    """Convert a quaternion ``q = (w, x, y, z)`` into a 3x3 rotation matrix.

    Applies the unit-quaternion formula without normalising ``q``. As a
    consequence the all-zero quaternion maps to the identity matrix, which
    other code in this file uses as its "no rotation" value. The result is
    built with ``torch.tensor`` on ``q``'s device/dtype, i.e. as a fresh tensor
    without autograd history from ``q``.
    """
    w, x, y, z = q[0], q[1], q[2], q[3]
    row0 = [1 - 2 * y ** 2 - 2 * z ** 2, 2 * (x * y - w * z), 2 * (x * z + w * y)]
    row1 = [2 * (x * y + w * z), 1 - 2 * x ** 2 - 2 * z ** 2, 2 * (y * z - w * x)]
    row2 = [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * x ** 2 - 2 * y ** 2]
    return torch.tensor([row0, row1, row2], device=q.device, dtype=q.dtype)
|
|
|
|
|
|
def apply_rot_matrix(x: Data, rot_matrix: torch.Tensor) -> Data:
    """Rotate the points of ``x.coords`` (one per row) about their centroid.

    Each point ``p`` becomes ``R @ (p - c) + c`` where ``c`` is the mean point,
    so the centroid is preserved. ``x.coords`` is rebound and ``x`` returned.
    """
    centroid = x.coords.mean(dim=0)
    rotated = (x.coords - centroid) @ rot_matrix.T
    x.coords = rotated + centroid
    return x
|
|
|
|
|
|
def _rodrigues(
|
|
points: torch.Tensor,
|
|
pivot: torch.Tensor,
|
|
axis_unit: torch.Tensor,
|
|
angle: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
v = points - pivot
|
|
c, s = torch.cos(angle), torch.sin(angle)
|
|
cross = torch.cross(axis_unit.expand_as(v), v, dim=-1)
|
|
dot_u = (v * axis_unit).sum(-1, keepdim=True)
|
|
return pivot + v * c + cross * s + axis_unit * (dot_u * (1 - c))
|
|
|
|
|
|
def apply_torsion(
    x: Any,
    angles: torch.Tensor,
    movable_mask: torch.Tensor,
    bond_endpoints: torch.Tensor,
) -> Any:
    """Rotate the movable side of each listed bond in place, one bond after another.

    Args:
        x: object whose ``coords`` tensor holds one 3D point per row; it is
            modified in place (``index_put_``).
        angles: one angle (radians) per bond; reshaped to ``[B]``.
        movable_mask: ``[B, N]`` mask; row ``b`` selects the atoms to rotate.
        bond_endpoints: ``[B, 2]`` atom indices ``(i, j)`` defining each axis.

    For bond ``b`` the selected atoms are rotated by ``angles[b]`` about the
    axis through atom ``i`` pointing at atom ``j`` (right-hand rule). Later
    bonds see the coordinates produced by earlier ones.

    Returns:
        ``x`` (same object).
    """
    coords = x.coords
    n_bonds = angles.shape[0]
    target_device = coords.device
    endpoints = bond_endpoints.to(device=target_device, dtype=torch.long)
    masks = movable_mask.to(device=target_device)
    theta = angles.to(device=target_device, dtype=coords.dtype).reshape(n_bonds)
    for b in range(n_bonds):
        i = int(endpoints[b, 0])
        j = int(endpoints[b, 1])
        axis = coords[j] - coords[i]
        axis = axis / axis.norm()
        moving = masks[b].nonzero(as_tuple=True)[0]
        rotated = _rodrigues(coords[moving], coords[i], axis, theta[b])
        coords.index_put_((moving,), rotated)
    return x
|
|
|
|
|
|
def generate_random_torsion_angles(
    num_bonds: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Return ``num_bonds`` i.i.d. angles from ``U(0, 2*pi)`` using the global torch RNG."""
    return torch.empty(num_bonds, dtype=dtype, device=device).uniform_(0, 2 * torch.pi)
|
|
|
|
|
|
def _movable_j_side(edge_index: torch.Tensor, i: int, j: int, num_atoms: int) -> torch.Tensor:
|
|
ei = edge_index[:2]
|
|
adj = [[] for _ in range(num_atoms)]
|
|
for e in range(ei.shape[1]):
|
|
u, v = int(ei[0, e]), int(ei[1, e])
|
|
if {u, v} == {i, j}:
|
|
continue
|
|
adj[u].append(v)
|
|
adj[v].append(u)
|
|
seen = {j}
|
|
q = deque([j])
|
|
while q:
|
|
u = q.popleft()
|
|
for v in adj[u]:
|
|
if v not in seen:
|
|
seen.add(v)
|
|
q.append(v)
|
|
m = torch.zeros(num_atoms, dtype=torch.bool)
|
|
m[list(seen)] = True
|
|
return m
|
|
|
|
|
|
def bfs_torsion_masks(x: Any, root: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the rotatable bonds of ``x.rdmol`` and the atoms each torsion moves.

    A bond is rotatable when it matches the SMARTS below: a single, non-ring
    bond whose two atoms are neither of explicit degree 1 nor part of a
    triple bond. Bonds are ordered by a breadth-first search over
    ``x.edge_index`` from ``root`` (neighbours visited in ascending index
    order) and oriented parent -> child, so ``j`` is the endpoint farther from
    ``root``. Rotatable bonds the BFS never reaches (e.g. in fragments that are
    disconnected from ``root``) are appended afterwards, oriented and sorted by
    ``(depth, index)``.

    Args:
        x: object with ``rdmol``, ``coords`` (one row per atom) and ``edge_index``.
        root: atom index where the BFS starts.

    Returns:
        ``(bond_endpoints, movable_mask)``: a ``[T, 2]`` long tensor of
        ``(i, j)`` pairs and a ``[T, N]`` bool tensor whose row ``b`` marks the
        atoms on the ``j`` side of bond ``b``. When the molecule has no
        rotatable bonds, ``T == 0`` and the shapes are ``[0, 2]`` / ``[0, N]``.
    """
    rot_pat = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")
    rot_set = {frozenset((int(a), int(b))) for a, b in x.rdmol.GetSubstructMatches(rot_pat)}
    N = int(x.coords.shape[0])
    ei = x.edge_index

    adj: list[list[int]] = [[] for _ in range(N)]
    for e in range(ei.shape[1]):
        u, v = int(ei[0, e]), int(ei[1, e])
        adj[u].append(v)
        adj[v].append(u)
    for u in range(N):
        adj[u].sort()  # deterministic BFS order

    depth = [-1] * N
    depth[root] = 0
    q = deque([root])
    ordered_ij: list[tuple[int, int]] = []
    tree_pairs: set[frozenset[int]] = set()
    while q:
        u = q.popleft()
        for v in adj[u]:
            if depth[v] != -1:
                continue
            depth[v] = depth[u] + 1
            q.append(v)
            fs = frozenset((u, v))
            tree_pairs.add(fs)
            if fs in rot_set:
                ordered_ij.append((u, v))

    def orient(fs: frozenset[int]) -> tuple[int, int]:
        # Shallower (then lower-index) atom first, so j is on the far side.
        a, b = tuple(fs)
        da, db = depth[a], depth[b]
        if da < db or (da == db and a < b):
            return (a, b)
        return (b, a)

    for fs in sorted(rot_set - tree_pairs, key=orient):
        ordered_ij.append(orient(fs))

    if not ordered_ij:
        # torch.tensor([]) would be 1-D and torch.stack([]) raises, so build
        # correctly shaped empty tensors for rigid molecules.
        return torch.empty((0, 2), dtype=torch.long), torch.zeros((0, N), dtype=torch.bool)

    bond_endpoints = torch.tensor(ordered_ij, dtype=torch.long)
    movable_mask = torch.stack([_movable_j_side(ei, i, j, N) for i, j in ordered_ij])
    return bond_endpoints, movable_mask
|
|
|
|
|
|
def vel_center(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Constant velocity that carries ``current`` to ``target`` in the remaining time ``1 - t``.

    ``1 - t`` is clamped to at least ``1e-6`` so that ``t -> 1`` does not divide by zero.
    """
    remaining = (1 - t).clamp(min=1e-6)
    return (target - current) / remaining
|
|
|
|
|
|
def wrap_angle(d: torch.Tensor) -> torch.Tensor:
    """Map angles (radians) elementwise onto the principal interval ``[-pi, pi]``.

    Uses ``atan2(sin d, cos d)``, which works for any number of full turns.
    """
    sin_d = torch.sin(d)
    cos_d = torch.cos(d)
    return torch.atan2(sin_d, cos_d)
|
|
|
|
|
|
def vel_torsion(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Angular velocity that closes the shortest signed gap ``current -> target`` over ``1 - t``.

    The gap is wrapped to ``[-pi, pi]`` with ``wrap_angle``; ``1 - t`` is clamped to ``>= 1e-6``.
    """
    shortest_gap = wrap_angle(target - current)
    return shortest_gap / (1 - t).clamp(min=1e-6)
|
|
|
|
|
|
def torsion_wrapped_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared angular error, with each difference first wrapped to ``[-pi, pi]``.

    Wrapping makes angles that differ by whole turns count as equal.
    """
    err = wrap_angle(pred - target)
    return (err ** 2).mean()
|
|
|
|
|
|
def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor) -> torch.Tensor:
    """Mean squared geodesic (rotation-angle) distance between axis-angle vectors.

    For each row ``b`` this computes ``|log(exp(pred_b)^T exp(target_b))|^2``,
    the squared angle of the relative rotation, and returns the batch mean.

    Args:
        pred_omega: ``[B, 3]`` predicted axis-angle vectors (``B >= 1``).
        target_omega: ``[B, 3]`` target axis-angle vectors; must have the same
            shape as ``pred_omega``.

    Raises:
        ValueError: if ``pred_omega`` is not ``[B, 3]`` or ``target_omega``
            does not have the same shape.
    """
    if pred_omega.ndim != 2 or pred_omega.shape[-1] != 3:
        raise ValueError("pred_omega must be [B,3]")
    if target_omega.shape != pred_omega.shape:
        # Without this check a longer target would be silently truncated and a
        # shorter one would fail later with an unrelated IndexError.
        raise ValueError(
            f"target_omega must match pred_omega shape {tuple(pred_omega.shape)}, "
            f"got {tuple(target_omega.shape)}"
        )
    losses = []
    for pred_i, tgt_i in zip(pred_omega, target_omega):
        # Relative rotation R_pred^T R_tgt; the norm of its log is the geodesic angle.
        delta = so3_log(so3_exp(pred_i).T @ so3_exp(tgt_i))
        losses.append(torch.sum(delta * delta))
    return torch.stack(losses).mean()
|
|
|
|
|
|
def skew(w: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 cross-product matrix ``[w]_x`` so that ``skew(w) @ v == w x v``.

    The matrix is assembled with ``torch.stack``, so gradients flow back to ``w``.
    """
    zero = torch.zeros((), device=w.device, dtype=w.dtype)
    wx, wy, wz = w[0], w[1], w[2]
    rows = [
        torch.stack([zero, -wz, wy]),
        torch.stack([wz, zero, -wx]),
        torch.stack([-wy, wx, zero]),
    ]
    return torch.stack(rows)
|
|
|
|
|
|
def so3_log(R: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Logarithm map of a single rotation matrix to an axis-angle vector ``omega`` ([3]).

    For a rotation by ``theta`` about unit axis ``a``:
    ``v = (R21 - R12, R02 - R20, R10 - R01) = 2 sin(theta) a`` and
    ``cos(theta) = (tr R - 1) / 2``. The angle is obtained with ``atan2``
    instead of ``acos``. ``acos`` has an infinite derivative at +-1, and in
    float32 a clamp to ``1 - 1e-8`` does not help because that bound rounds to
    ``1.0``, so near-identity rotations produced NaN gradients.

    Two regimes where ``v`` is numerically zero are handled separately:

    * ``theta -> 0``: ``theta / (2 sin theta) = 1/2 + theta^2 / 12 + ...``;
    * ``theta -> pi``: the axis is recovered from the symmetric part
      ``(R + R^T) / 2 = cos(theta) I + (1 - cos(theta)) a a^T`` and the
      returned magnitude is ``pi``. The axis sign follows ``v`` when ``v``
      has one; at exactly ``pi`` both signs are valid logs.

    Every branch is evaluated on safe inputs and merged with ``torch.where``,
    so the result stays differentiable without host synchronisation.

    Args:
        R: 3x3 rotation matrix (unbatched).
        eps: lower bound applied before the square root that normalises the
            axis in the ``theta -> pi`` branch.

    Returns:
        ``omega`` with ``|omega| = theta`` in ``[0, pi]``.
    """
    tiny_sq = 1e-8  # threshold on 4 sin^2(theta): |sin(theta)| < 5e-5

    v = torch.stack([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    cos_theta = (R.trace() - 1) * 0.5
    four_sin_sq = (v * v).sum()

    tiny = four_sin_sq < tiny_sq
    near_zero = tiny & (cos_theta > 0)
    near_pi = tiny & ~near_zero
    one = torch.ones_like(cos_theta)

    # General regime. sqrt receives 1 in the degenerate regimes so its gradient
    # stays finite there (where() would otherwise propagate 0 * inf = NaN).
    sin_theta = 0.5 * torch.where(tiny, one, four_sin_sq).sqrt()
    theta = torch.atan2(sin_theta, cos_theta)
    general = v * (theta / (2 * sin_theta))

    # theta -> 0: series with theta^2 ~= sin^2(theta) = four_sin_sq / 4.
    small_angle = v * (0.5 + four_sin_sq / 48.0)

    # theta -> pi: axis from the symmetric part; 1 - cos(theta) ~= 2 there.
    eye = torch.eye(3, device=R.device, dtype=R.dtype)
    denom = torch.where(near_pi, 1 - cos_theta, one)
    outer = (0.5 * (R + R.T) - cos_theta * eye) / denom  # ~= a a^T
    k = torch.argmax(torch.diagonal(outer))
    axis = outer[:, k] / outer[k, k].clamp(min=eps).sqrt()
    axis = torch.where((axis * v).sum() < 0, -axis, axis)
    near_pi_vec = axis * torch.pi

    out = torch.where(near_zero, small_angle, general)
    return torch.where(near_pi, near_pi_vec, out)
|
|
|
|
|
|
def so3_exp(omega: torch.Tensor) -> torch.Tensor:
    """Exponential map: axis-angle vector ``omega`` ([3]) -> 3x3 rotation matrix.

    Rodrigues' form ``I + sin(theta) K + (1 - cos(theta)) K^2`` with
    ``theta = |omega|`` and ``K = skew(omega / theta)``. The ``1e-12`` added to
    the norm keeps ``omega = 0`` finite, in which case the identity is returned.
    """
    angle = omega.norm()
    K = skew(omega / (angle + 1e-12))
    identity = torch.eye(3, device=omega.device, dtype=omega.dtype)
    return identity + torch.sin(angle) * K + (1 - torch.cos(angle)) * (K @ K)
|
|
|
|
|
|
def vel_rot_mat(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Angular velocity (axis-angle per unit time) that rotates ``current`` into ``target``.

    The relative rotation ``current^T @ target`` is log-mapped and spread over
    the remaining time ``1 - t`` (clamped to ``>= 1e-6``). ``get_state``
    applies it on the right: ``current @ so3_exp(omega * dt)``.
    """
    relative = current.T @ target
    return so3_log(relative) / (1 - t).clamp(min=1e-6)
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class RigidTorsionState:
    """Pose of a ligand expressed as translation, rotation and torsion angles.

    ``graph_from_state_cpu`` turns such a state into coordinates by applying
    it to the ground-truth conformer.
    """

    # Translation offset, shape [3].
    center: torch.Tensor
    # 3x3 rotation matrix.
    rot: torch.Tensor
    # One angle (radians) per rotatable bond.
    torsion: torch.Tensor
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class RigidTorsionVelocity:
    """Time derivative of a ``RigidTorsionState``: one rate per pose component."""

    # Translation velocity, shape [3] (or [B, 3] for model outputs).
    center: torch.Tensor
    # Rotation rate as an axis-angle vector per unit time.
    rot_omega: torch.Tensor
    # Angular velocity per rotatable bond.
    torsion: torch.Tensor
|
|
|
|
|
|
def rigid_torsion_velocity(
    current: RigidTorsionState, target: RigidTorsionState, t: torch.Tensor
) -> RigidTorsionVelocity:
    """Constant velocities that bring every pose component of ``current`` to ``target`` at time 1.

    ``t`` may be a float or a tensor (converted with ``torch.as_tensor``).
    Translation uses ``vel_center``, rotation ``vel_rot_mat`` and torsions the
    wrapped ``vel_torsion``.
    """
    t = torch.as_tensor(t)
    return RigidTorsionVelocity(
        center=vel_center(current.center, target.center, t),
        rot_omega=vel_rot_mat(current.rot, target.rot, t),
        torsion=vel_torsion(current.torsion, target.torsion, t),
    )
|
|
|
|
|
|
def get_state(
    current: RigidTorsionState, velocity: RigidTorsionVelocity, t: float
) -> RigidTorsionState:
    """Advance ``current`` by ``velocity`` for a duration ``t`` (one exact step per component).

    Translation and torsions move linearly; the rotation is updated on the
    right as ``current.rot @ so3_exp(omega * t)``. Torsion angles are not
    re-wrapped.
    """
    dt = torch.as_tensor(t)
    return RigidTorsionState(
        center=current.center + velocity.center * dt,
        rot=current.rot @ so3_exp(velocity.rot_omega * dt),
        torsion=current.torsion + velocity.torsion * dt,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Visualization
|
|
# ---------------------------------------------------------------------------
|
|
def save_visualization(sample: Data, path: str) -> None:
    """Write ``sample`` at its current ``coords`` to ``path`` with ``Chem.MolToMolFile``.

    Positions are written into the default conformer of a deep copy of
    ``sample.rdmol``, so the caller's molecule is not modified.
    """
    mol_copy = deepcopy(sample.rdmol)
    conf = mol_copy.GetConformer()
    for atom_idx, (x, y, z) in enumerate(sample.coords.tolist()):
        conf.SetAtomPosition(atom_idx, Point3D(x, y, z))
    Chem.MolToMolFile(mol_copy, path)
|
|
|
|
|
|
def save_video_visualization(
    sample: Data,
    states: list[RigidTorsionState],
    path: str,
) -> None:
    """Write a trajectory as a multi-record SDF file, one record per state.

    Each frame is built with ``graph_from_state_cpu`` from ``sample.gt_coords``
    (translate, rotate about the centroid, then apply torsions). Its
    coordinates are written into a deep copy of ``sample.rdmol``. ``sample`` is
    not modified.

    Args:
        sample: prepared graph with ``gt_coords``, ``rdmol``, ``movable_mask``
            and ``bond_endpoints``.
        states: poses to render, in order.
        path: output SDF path (overwritten).
    """
    mol = deepcopy(sample.rdmol)
    conformer = mol.GetConformer()
    with open(path, "w") as f:
        writer = SDWriter(f)
        try:
            for state in states:
                frame = graph_from_state_cpu(sample, state.center, state.rot, state.torsion)
                for atom_idx, (x, y, z) in enumerate(frame.coords.tolist()):
                    conformer.SetAtomPosition(atom_idx, Point3D(x, y, z))
                writer.write(mol)
        finally:
            # Flush/close the writer even if a frame fails, before the file closes.
            writer.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Evaluation
|
|
# ---------------------------------------------------------------------------
|
|
def kabsch_rmsd(pred: torch.Tensor, target: torch.Tensor) -> float:
    """RMSD between two ``[N, 3]`` point sets after optimal rigid superposition (Kabsch).

    Both sets are centred. With ``H = P^T Q = U S V^T``, the proper rotation
    ``R = V D U^T`` maps column vectors of ``P`` onto ``Q``; ``D`` flips the
    last axis when needed so that ``det(R) = +1`` (no reflections). Points are
    stored as rows, so they transform as ``P @ R^T``.

    Returns:
        The root-mean-square per-point distance as a Python float.
    """
    p = pred - pred.mean(dim=0, keepdim=True)
    q = target - target.mean(dim=0, keepdim=True)
    h = p.T @ q
    u, _, v_t = torch.linalg.svd(h)
    r = v_t.T @ u.T
    if torch.det(r) < 0:
        # Reflection: flip the axis of the smallest singular value.
        v_t[-1, :] *= -1
        r = v_t.T @ u.T
    # Row-vector points rotate with R^T; using R itself would apply the inverse rotation.
    p_aligned = p @ r.T
    return float(torch.sqrt(torch.mean(torch.sum((p_aligned - q) ** 2, dim=-1))).item())
|
|
|
|
|
|
def graph_from_state_cpu(sample: Data, center: torch.Tensor, rot: torch.Tensor, torsion: torch.Tensor) -> Data:
    """Return a deep copy of ``sample`` posed at ``(center, rot, torsion)``.

    Starting from ``sample.gt_coords``, the copy is translated by ``center``,
    rotated about its centroid by ``rot`` and then has ``torsion`` applied to
    its rotatable bonds. ``sample`` itself is left untouched.
    """
    posed = deepcopy(sample)
    posed.coords = sample.gt_coords.clone()
    posed = apply_rot_matrix(apply_trans(posed, center), rot)
    return apply_torsion(posed, torsion, posed.movable_mask, posed.bond_endpoints)
|
|
|
|
|
|
def evaluate_mean_rmsd_100(
    model: nn.Module,
    sample: Data,
    target_coords: torch.Tensor,
    device: torch.device,
    num_runs: int = 100,
    rollout_steps: int = 100,
) -> float:
    """Mean Kabsch RMSD of one-shot predictions from ``num_runs`` random start poses.

    Each run draws a random start state: center ~ N(0, 3^2) per axis, a
    uniform rotation and uniform torsions. All runs are queried in a single
    batched forward pass at ``t = 0``. One full Euler step to ``t = 1`` follows
    (``center + v``, ``torsion + v``, ``rot @ so3_exp(omega)``). Each final
    conformer is then scored against ``target_coords`` with ``kabsch_rmsd``.

    The forward pass runs under ``torch.no_grad`` in eval mode. The model's
    previous train/eval mode is restored afterwards.

    Args:
        model: velocity model called as ``model(batch, t)``.
        sample: prepared graph (needs ``gt_coords``, ``bond_endpoints``,
            ``movable_mask``).
        target_coords: reference coordinates for the RMSD.
        device: device for the forward pass.
        num_runs: number of random starts (must be >= 1).
        rollout_steps: accepted for interface compatibility but ignored; the
            evaluation always uses the single step described above.

    Returns:
        Average RMSD over the runs.

    Raises:
        ValueError: if ``num_runs < 1``.
    """
    del rollout_steps
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1")

    nb = sample.bond_endpoints.shape[0]
    centers, rots, torsions = [], [], []
    for _ in range(num_runs):
        centers.append(generate_random_trans(torch.zeros(3), torch.tensor([3.0, 3.0, 3.0])))
        rots.append(generate_rotation_matrix(generate_random_quaternion()))
        torsions.append(generate_random_torsion_angles(nb))
    center_t = torch.stack(centers, dim=0)  # [R, 3]
    rot_t = torch.stack(rots, dim=0)  # [R, 3, 3]
    torsion_t = torch.stack(torsions, dim=0)  # [R, T]

    graphs = [
        graph_from_state_cpu(sample, center_t[i], rot_t[i], torsion_t[i])
        for i in range(num_runs)
    ]
    batch = Batch.from_data_list(graphs).to(device)
    tb = torch.zeros((num_runs, 1), device=device, dtype=torch.float32)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            v = model(batch, tb)
    finally:
        # Do not leave a model that is still being trained stuck in eval mode.
        model.train(was_training)

    # Single Euler step of size 1 (t: 0 -> 1).
    center_t = center_t + v.center.detach().cpu()
    torsion_t = torsion_t + v.torsion.detach().cpu()
    vw = v.rot_omega.detach().cpu()
    for i in range(num_runs):
        rot_t[i] = rot_t[i] @ so3_exp(vw[i])

    rmsds = [
        kabsch_rmsd(
            graph_from_state_cpu(sample, center_t[i], rot_t[i], torsion_t[i]).coords,
            target_coords,
        )
        for i in range(num_runs)
    ]
    return float(sum(rmsds) / len(rmsds))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model
|
|
# ---------------------------------------------------------------------------
|
|
class RFMModel(nn.Module):
    """Velocity network for rigid-body + torsion flow matching.

    An encoder maps the posed coordinates to one feature vector per graph.
    That vector is concatenated with time features
    ``[t, sin(2 pi t), cos(2 pi t)]``, optionally layer-normalised, and fed to
    three independent heads. The heads predict translation velocity (3),
    rotation velocity as an axis-angle rate (3) and one angular velocity per
    rotatable bond (``num_torsions``).

    Encoders:
        ``"gcn"``: stacked ``GCNConv`` layers on raw coordinates (3 input
            features per node) with SiLU, optional identity residuals between
            equal-width layers, then ``global_mean_pool``.
        ``"mlp"``: flattens each graph's ``num_nodes * 3`` coordinates in a
            fixed atom order and applies ``num_layers`` Linear+SiLU layers;
            every graph in a batch must have exactly ``num_nodes`` atoms.

    Args:
        model_type: ``"gcn"`` or ``"mlp"``; anything else raises ``ValueError``.
        num_nodes: atoms per graph (used by the MLP encoder).
        num_torsions: output width of the torsion head.
        hidden: encoder width.
        num_layers: GCN layers (at least one is always built) or MLP layers
            (``0`` feeds the flattened coordinates straight to the heads).
        gcn_residual: add the layer input back onto a conv's output when the
            shapes match.
        channel_layernorm: LayerNorm on the pooled features and inside heads.
        head_mlp_layers: depth of each head (values below 1 are treated as 1).
    """

    def __init__(
        self,
        model_type: str,
        num_nodes: int,
        num_torsions: int,
        hidden: int = 128,
        num_layers: int = 4,
        gcn_residual: bool = False,
        channel_layernorm: bool = False,
        head_mlp_layers: int = 1,
    ):
        super().__init__()
        self.model_type = model_type
        self.num_nodes = num_nodes
        self.num_torsions = num_torsions
        self.hidden = hidden
        self.num_layers = num_layers
        self.gcn_residual = gcn_residual
        self.channel_layernorm = channel_layernorm
        self.act = nn.SiLU()
        self.head_mlp_layers = max(1, int(head_mlp_layers))

        if model_type == "gcn":
            self.convs = nn.ModuleList()
            self.convs.append(GCNConv(3, hidden))
            for _ in range(num_layers - 1):
                self.convs.append(GCNConv(hidden, hidden))
            # Parameter-free identity skips, one per conv layer.
            self.gcn_skip = nn.ModuleList([nn.Identity()])
            for _ in range(num_layers - 1):
                self.gcn_skip.append(nn.Identity())
            trunk_dim = hidden
        elif model_type == "mlp":
            mlp_layers = []
            in_dim = num_nodes * 3
            for _ in range(num_layers):
                mlp_layers.append(nn.Linear(in_dim, hidden))
                mlp_layers.append(nn.SiLU())
                in_dim = hidden
            self.encoder = nn.Sequential(*mlp_layers)
            # Actual encoder output width: ``hidden`` normally, but the raw
            # flattened coordinates when num_layers == 0 (empty Sequential).
            trunk_dim = in_dim
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

        self.time_feat_dim = 3
        d = trunk_dim + self.time_feat_dim
        self.shared_norm = nn.LayerNorm(d) if channel_layernorm else nn.Identity()

        def make_head(out_dim: int) -> nn.Sequential:
            # (head_mlp_layers - 1) hidden blocks of width d, then the output projection.
            layers: list[nn.Module] = []
            in_dim = d
            for _ in range(self.head_mlp_layers - 1):
                layers.append(nn.Linear(in_dim, d))
                if channel_layernorm:
                    layers.append(nn.LayerNorm(d))
                layers.append(nn.SiLU())
                in_dim = d
            layers.append(nn.Linear(in_dim, out_dim))
            return nn.Sequential(*layers)

        self.trans_head = make_head(3)
        self.rot_head = make_head(3)
        self.torsion_head = make_head(num_torsions)

    def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
        """Predict velocities for a batch of posed graphs.

        Args:
            x: batched graph with ``coords`` and, for the GCN encoder,
                ``edge_index`` and ``batch``.
            time: per-graph times. They are concatenated with the ``[B, D]``
                trunk features along the last dimension, so the expected shape
                is ``[B, 1]``.

        Returns:
            ``RigidTorsionVelocity`` with ``center`` [B, 3], ``rot_omega``
            [B, 3] and ``torsion`` [B, num_torsions].
        """
        if self.model_type == "gcn":
            h = x.coords
            edge_index = x.edge_index
            batch = x.batch
            for idx, conv in enumerate(self.convs):
                h_new = self.act(conv(h, edge_index))
                # Residual only between equal-width layers (the 3 -> hidden
                # input layer is skipped unless hidden == 3).
                if self.gcn_residual and h_new.shape == h.shape:
                    h = h_new + self.gcn_skip[idx](h)
                else:
                    h = h_new
            trunk = global_mean_pool(h, batch)
        else:
            bsz = time.shape[0]
            # Fixed atom order: every graph contributes exactly num_nodes rows.
            coords = x.coords.reshape(bsz, self.num_nodes * 3)
            trunk = self.encoder(coords)

        t = time.to(dtype=trunk.dtype, device=trunk.device)
        two_pi_t = 2.0 * torch.pi * t
        time_feats = torch.cat([t, torch.sin(two_pi_t), torch.cos(two_pi_t)], dim=-1)
        pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
        trans = self.trans_head(pooled)
        rot_omega = self.rot_head(pooled)
        torsion = self.torsion_head(pooled)
        return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Data loading
|
|
# ---------------------------------------------------------------------------
|
|
def load_sample_from_pickle(path: str, pos: int) -> Data:
    """Unpickle the record starting at byte offset ``pos`` of ``path`` and wrap it as ``Data``.

    Keeps ``coords``, the first two rows of ``edge_index``, ``rdmol`` and any of
    the optional ground-truth / torsion fields present on the record.

    Security: ``pickle.load`` can execute arbitrary code; only point this at
    trusted shard files.
    """
    with open(path, "rb") as handle:
        handle.seek(pos)
        record = pickle.load(handle)
    record.edge_index = record.edge_index[:2, :]
    data = Data(coords=record.coords, edge_index=record.edge_index)
    data.num_nodes = int(record.coords.shape[0])
    data.rdmol = record.rdmol
    optional_fields = ("gt_coords", "gt_trans", "gt_quaternion", "gt_torsion", "bond_endpoints", "movable_mask")
    for field_name in optional_fields:
        if not hasattr(record, field_name):
            continue
        setattr(data, field_name, getattr(record, field_name))
    return data
|
|
|
|
|
|
def load_sample_from_sdf(sdf_path: str) -> Data:
    """Read a molecule from a MOL/SDF file (hydrogens kept) into a ``Data`` graph.

    Sets:
        ``coords``: ``[N, 3]`` float32 positions from the molecule's conformer
            (``[0, 3]`` for a molecule without atoms);
        ``edge_index``: ``[2, 2E]`` long, every bond in both directions
            (``[2, 0]`` when the molecule has no bonds);
        ``num_nodes`` and ``rdmol`` (a copy of the parsed molecule).

    Raises:
        RuntimeError: if RDKit cannot parse ``sdf_path``.
    """
    mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
    if mol is None:
        raise RuntimeError(f"RDKit could not read {sdf_path}")
    conf = mol.GetConformer()
    # Fetch each position once instead of three times per atom.
    positions = [conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())]
    coords = torch.tensor(
        [[p.x, p.y, p.z] for p in positions],
        dtype=torch.float32,
    ).reshape(-1, 3)

    edges: list[tuple[int, int]] = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append((i, j))
        edges.append((j, i))
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).T.contiguous()
    else:
        # torch.tensor([]) would be 1-D, breaking users of edge_index[0] / shape[1].
        edge_index = torch.empty((2, 0), dtype=torch.long)

    data = Data()
    data.coords = coords
    data.edge_index = edge_index
    data.num_nodes = int(coords.shape[0])
    data.rdmol = Chem.Mol(mol)
    return data
|
|
|
|
|
|
def prepare_sample(sample: Data) -> Data:
    """Annotate ``sample`` in place with ground-truth pose fields and torsion topology.

    Sets (overwriting any existing values):
        ``num_nodes``, ``gt_coords`` (clone of the current coords),
        ``gt_trans`` (float32 centroid), ``gt_quaternion`` (all zeros, which
        ``generate_rotation_matrix`` maps to the identity),
        ``bond_endpoints`` / ``movable_mask`` from ``bfs_torsion_masks`` and
        ``gt_torsion`` (one zero per rotatable bond).

    Returns:
        ``sample`` (same object).
    """
    coords = sample.coords
    sample.num_nodes = int(coords.shape[0])
    sample.gt_coords = coords.clone()
    sample.gt_trans = torch.mean(coords, dim=0, dtype=torch.float32)
    sample.gt_quaternion = torch.zeros(4, dtype=torch.float32)
    endpoints, movable = bfs_torsion_masks(sample)
    sample.bond_endpoints, sample.movable_mask = endpoints, movable
    sample.gt_torsion = torch.zeros(endpoints.shape[0], dtype=torch.float32)
    return sample
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CPU batch building (DataLoader workers or main process)
|
|
# ---------------------------------------------------------------------------
|
|
def _gt_state_cpu(gs: RigidTorsionState) -> RigidTorsionState:
    """Return a new state whose tensors are on the CPU (tensors already there are shared, not copied)."""
    center, rot, torsion = (tensor.cpu() for tensor in (gs.center, gs.rot, gs.torsion))
    return RigidTorsionState(center=center, rot=rot, torsion=torsion)
|
|
|
|
|
|
def make_rfm_example(
    sample: Data,
    gt_state_cpu: RigidTorsionState,
    cpu: torch.device,
    time_power: float = 1.0,
) -> dict[str, Any]:
    """Build one training example on CPU (intended to run inside DataLoader workers).

    1. Draw a random start state: center ~ N(0, 3^2) per axis, uniform
       rotation, torsions ~ U(0, 2 pi).
    2. Compute the constant velocity that reaches ``gt_state_cpu`` at ``t = 1``.
    3. Sample ``t = U(0, 1) ** time_power`` and move the start state to time ``t``.
    4. Pose a deep copy of ``sample`` at that intermediate state.

    Returns:
        Dict with ``"data"`` (posed graph), the velocity targets
        ``"v_center"``, ``"v_omega"``, ``"v_torsion"`` and the 0-d ``"time"``.
    """
    # Keyword arguments are evaluated left to right, which keeps the RNG draw
    # order: translation, rotation, torsions.
    start = RigidTorsionState(
        center=generate_random_trans(
            torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
            torch.tensor([3, 3, 3], dtype=torch.float32, device=cpu),
        ),
        rot=generate_rotation_matrix(generate_random_quaternion()),
        torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0], device=cpu),
    )
    velocity = rigid_torsion_velocity(start, gt_state_cpu, torch.tensor(0.0, device=cpu))
    t_scalar = torch.rand((), device=cpu) ** time_power
    state_t = get_state(start, velocity, float(t_scalar.item()))

    graph = deepcopy(sample)
    graph.coords = graph.coords.to(device=cpu, dtype=torch.float32)
    graph.edge_index = graph.edge_index.to(device=cpu)
    for attr in ("movable_mask", "bond_endpoints"):
        if hasattr(graph, attr):
            setattr(graph, attr, getattr(graph, attr).to(device=cpu))

    apply_trans(graph, state_t.center)
    apply_rot_matrix(graph, state_t.rot)
    apply_torsion(graph, state_t.torsion, graph.movable_mask, graph.bond_endpoints)

    return {
        "data": graph,
        "v_center": velocity.center.detach().cpu(),
        "v_omega": velocity.rot_omega.detach().cpu(),
        "v_torsion": velocity.torsion.detach().cpu(),
        "time": t_scalar.detach().cpu(),
    }
|
|
|
|
|
|
def rfm_collate(batch: list[dict[str, Any]]) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate ``make_rfm_example`` dicts for a DataLoader.

    Returns ``(graphs, time, v_center, v_omega, v_torsion)``. ``graphs`` is a
    PyG ``Batch`` of the posed graphs. Each remaining field is stacked along a
    new leading batch dimension; ``time`` also gets a trailing singleton
    dimension (``[B, 1]`` for 0-d per-example times).
    """
    graphs = Batch.from_data_list([example["data"] for example in batch])

    def stack_field(key: str) -> torch.Tensor:
        return torch.stack([example[key] for example in batch], dim=0)

    time_b = stack_field("time").unsqueeze(-1)
    return graphs, time_b, stack_field("v_center"), stack_field("v_omega"), stack_field("v_torsion")
|
|
|
|
|
|
class SameGraphRandomRFMDataset(Dataset):
    """Map-style dataset that ignores the index and returns a fresh random example on every access.

    All items share ``sample``'s topology. ``length`` only defines ``len()``,
    i.e. how many items a DataLoader pass draws. Examples are generated on CPU
    by ``make_rfm_example`` using the global torch RNG.
    """

    def __init__(
        self,
        sample: Data,
        gt_state_cpu: RigidTorsionState,
        length: int,
        time_power: float = 1.0,
    ):
        self.sample = sample
        # Targets are kept on CPU so workers never touch the GPU.
        self.gt_state = _gt_state_cpu(gt_state_cpu)
        self.length = length
        self.time_power = time_power

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        del idx  # every access draws a new random pose; the index carries no meaning
        return make_rfm_example(
            self.sample,
            self.gt_state,
            cpu=torch.device("cpu"),
            time_power=self.time_power,
        )
|
|
|
|
|
|
def build_one_batch(
    sample: Data,
    gt_state: RigidTorsionState,
    batch_size: int,
    cpu: torch.device,
    time_power: float = 1.0,
) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build and collate ``batch_size`` examples in the calling process (no workers).

    Produces the same tuple layout as a DataLoader that uses ``rfm_collate``.
    """
    return rfm_collate(
        [make_rfm_example(sample, gt_state, cpu, time_power=time_power) for _ in range(batch_size)]
    )
|
|
|
|
|
|
def infinite_batches(loader: DataLoader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one ends.

    Each pass iterates ``loader`` afresh, so a DataLoader with
    ``shuffle=True`` sees a new order every pass.

    Raises:
        RuntimeError: if a complete pass over ``loader`` yields nothing. Without
            this check the generator would loop forever without yielding.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError("infinite_batches: loader yielded no batches (empty dataset?)")
|
|
|
|
|
|
def main() -> int:
|
|
here = os.path.dirname(os.path.abspath(__file__))
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--pickle-path", default="", help="Optional pickle shard path")
|
|
parser.add_argument("--pickle-pos", type=int, default=61908729)
|
|
parser.add_argument("--sdf", default=os.path.join(here, "sample.sdf"), help="Fallback ligand SDF")
|
|
parser.add_argument("--out-dir", default=os.path.join(here, "rfm_out"))
|
|
parser.add_argument("--epochs", type=int, default=2000)
|
|
parser.add_argument("--batch-size", type=int, default=32)
|
|
parser.add_argument("--lr", type=float, default=1e-3)
|
|
parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
|
|
parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
|
|
parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
|
|
parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
|
|
parser.add_argument("--seed", type=int, default=0)
|
|
parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
|
|
parser.add_argument(
|
|
"--model-type",
|
|
type=str,
|
|
default="mlp",
|
|
choices=["mlp", "gcn"],
|
|
help="Backbone encoder for velocity prediction.",
|
|
)
|
|
parser.add_argument(
|
|
"--gcn-layers",
|
|
type=int,
|
|
default=4,
|
|
help="Number of GCNConv layers (more => longer GPU kernels)",
|
|
)
|
|
parser.add_argument(
|
|
"--accum",
|
|
type=int,
|
|
default=1,
|
|
help="Gradient accumulation steps per optimizer step (keeps GPU busier vs CPU batch build)",
|
|
)
|
|
parser.add_argument(
|
|
"--compile",
|
|
action="store_true",
|
|
help="torch.compile(model) when supported (PyTorch 2+)",
|
|
)
|
|
parser.add_argument(
|
|
"--num-workers",
|
|
type=int,
|
|
default=-1,
|
|
help="DataLoader workers; -1 = min(8, CPU-1), 0 = main process only",
|
|
)
|
|
parser.add_argument(
|
|
"--prefetch-factor",
|
|
type=int,
|
|
default=4,
|
|
help="Per-worker prefetch queue depth (ignored if num-workers=0)",
|
|
)
|
|
parser.add_argument(
|
|
"--pin-memory",
|
|
action="store_true",
|
|
help="DataLoader pin_memory (faster H2D; can hit shm/pin issues in some setups)",
|
|
)
|
|
parser.add_argument(
|
|
"--sync-batch",
|
|
action="store_true",
|
|
help="Build batches in main thread (no DataLoader workers)",
|
|
)
|
|
parser.add_argument(
|
|
"--dataset-len",
|
|
type=int,
|
|
default=1_000_000,
|
|
help="Synthetic Dataset length (only affects shuffle / epoch size)",
|
|
)
|
|
parser.add_argument(
|
|
"--eval-runs",
|
|
type=int,
|
|
default=100,
|
|
help="Number of rollout tests for final RMSD report",
|
|
)
|
|
parser.add_argument(
|
|
"--rollout-steps",
|
|
type=int,
|
|
default=100,
|
|
help="Integration steps per rollout during final test",
|
|
)
|
|
parser.add_argument(
|
|
"--time-power",
|
|
type=float,
|
|
default=1.0,
|
|
help="Sample t as U(0,1)^time_power; >1 biases training toward t~0 states.",
|
|
)
|
|
parser.add_argument(
|
|
"--loss-domain",
|
|
type=str,
|
|
default="displacement",
|
|
choices=["velocity", "displacement"],
|
|
help="Train on raw velocity or displacement-to-target (v * (1 - t)).",
|
|
)
|
|
parser.add_argument(
|
|
"--rotation-loss",
|
|
type=str,
|
|
default="mse",
|
|
choices=["mse", "geodesic"],
|
|
help="Rotation-channel loss type.",
|
|
)
|
|
parser.add_argument(
|
|
"--gcn-residual",
|
|
action="store_true",
|
|
help="Enable residual skip in GCN trunk.",
|
|
)
|
|
parser.add_argument(
|
|
"--channel-layernorm",
|
|
action="store_true",
|
|
help="Apply layernorm to pooled trunk and head hidden layers.",
|
|
)
|
|
parser.add_argument(
|
|
"--head-mlp-layers",
|
|
type=int,
|
|
default=1,
|
|
help="Per-channel head depth (>=1).",
|
|
)
|
|
parser.add_argument(
|
|
"--early-stop-patience",
|
|
type=int,
|
|
default=0,
|
|
help="Train-loss early stop patience in checks (0 disables).",
|
|
)
|
|
parser.add_argument(
|
|
"--early-stop-min-delta",
|
|
type=float,
|
|
default=0.0,
|
|
help="Minimum train-loss improvement required to reset patience.",
|
|
)
|
|
parser.add_argument(
|
|
"--early-stop-check-every",
|
|
type=int,
|
|
default=1,
|
|
help="Evaluate early-stop condition every N epochs.",
|
|
)
|
|
parser.add_argument(
|
|
"--early-stop-warmup",
|
|
type=int,
|
|
default=0,
|
|
help="Skip early-stop checks for first N epochs.",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
torch.manual_seed(args.seed)
|
|
if torch.cuda.is_available():
|
|
torch.backends.cudnn.benchmark = True
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
cpu = torch.device("cpu")
|
|
|
|
if args.pickle_path and os.path.isfile(args.pickle_path):
|
|
print("Loading sample from pickle:", args.pickle_path, "pos", args.pickle_pos)
|
|
sample = load_sample_from_pickle(args.pickle_path, args.pickle_pos)
|
|
elif os.path.isfile(args.sdf):
|
|
print("Loading sample from SDF:", args.sdf)
|
|
sample = load_sample_from_sdf(args.sdf)
|
|
else:
|
|
print("No pickle or SDF found. Provide --pickle-path or --sdf", file=sys.stderr)
|
|
return 1
|
|
|
|
sample = prepare_sample(sample)
|
|
os.makedirs(args.out_dir, exist_ok=True)
|
|
save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))
|
|
|
|
gt_state = RigidTorsionState(
|
|
center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
|
|
rot=generate_rotation_matrix(torch.tensor([0, 0, 0, 0], dtype=torch.float32)),
|
|
torsion=torch.zeros(sample.bond_endpoints.shape[0], dtype=torch.float32),
|
|
)
|
|
|
|
target_state = RigidTorsionState(
|
|
center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
|
|
rot=generate_rotation_matrix(sample.gt_quaternion),
|
|
torsion=sample.gt_torsion,
|
|
)
|
|
|
|
model = RFMModel(
|
|
model_type=args.model_type,
|
|
num_nodes=int(sample.num_nodes),
|
|
num_torsions=int(sample.bond_endpoints.shape[0]),
|
|
hidden=args.hidden,
|
|
num_layers=args.gcn_layers,
|
|
gcn_residual=args.gcn_residual,
|
|
channel_layernorm=args.channel_layernorm,
|
|
head_mlp_layers=args.head_mlp_layers,
|
|
).to(device)
|
|
if args.compile and hasattr(torch, "compile"):
|
|
model = torch.compile(model) # type: ignore[assignment]
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
|
|
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
print(
|
|
f"device={device} model={args.model_type} hidden={args.hidden} "
|
|
f"layers={args.gcn_layers} params={n_params:,}"
|
|
)
|
|
print(f"time sampling power={args.time_power}")
|
|
print(
|
|
"loss weights: "
|
|
f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
|
|
f"grad_clip={args.grad_clip}"
|
|
)
|
|
print("loss domain=%s torsion_wrapped_loss=always_on" % args.loss_domain)
|
|
print(
|
|
f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
|
|
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers}"
|
|
)
|
|
print(
|
|
"early_stop: "
|
|
f"patience={args.early_stop_patience} min_delta={args.early_stop_min_delta} "
|
|
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
|
|
)
|
|
|
|
nw = args.num_workers
|
|
if nw < 0:
|
|
nw = min(8, max(0, (os.cpu_count() or 4) - 1))
|
|
pf = args.prefetch_factor if nw > 0 else None
|
|
persistent = nw > 0
|
|
|
|
if args.sync_batch:
|
|
|
|
def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
return build_one_batch(
|
|
sample,
|
|
gt_state,
|
|
args.batch_size,
|
|
cpu,
|
|
args.time_power,
|
|
)
|
|
|
|
print(f"batching: sync (main thread) workers=0")
|
|
else:
|
|
dataset = SameGraphRandomRFMDataset(
|
|
sample,
|
|
gt_state_cpu=gt_state,
|
|
length=args.dataset_len,
|
|
time_power=args.time_power,
|
|
)
|
|
loader = DataLoader(
|
|
dataset,
|
|
batch_size=args.batch_size,
|
|
shuffle=True,
|
|
num_workers=nw,
|
|
prefetch_factor=pf,
|
|
persistent_workers=persistent,
|
|
pin_memory=args.pin_memory and torch.cuda.is_available(),
|
|
collate_fn=rfm_collate,
|
|
)
|
|
_batch_iter = infinite_batches(loader)
|
|
|
|
def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
return next(_batch_iter)
|
|
|
|
print(
|
|
f"batching: DataLoader num_workers={nw} prefetch_factor={pf!s} "
|
|
f"pin_memory={args.pin_memory and torch.cuda.is_available()} "
|
|
f"persistent_workers={persistent}"
|
|
)
|
|
|
|
t0 = time.perf_counter()
|
|
log_every = max(1, args.epochs // 10)
|
|
accum = max(1, args.accum)
|
|
best_train_mse = float("inf")
|
|
best_state: dict[str, torch.Tensor] | None = None
|
|
early_stop_best = float("inf")
|
|
early_stop_bad_checks = 0
|
|
stop_reason = "max_epochs"
|
|
stop_epoch = args.epochs - 1
|
|
|
|
for epoch in range(args.epochs):
|
|
optimizer.zero_grad(set_to_none=True)
|
|
sum_mse = torch.zeros((), device=device)
|
|
for _ in range(accum):
|
|
batched, time_b, tc, tw, tt = get_train_batch()
|
|
batched = batched.to(device, non_blocking=True)
|
|
time_b = time_b.to(device=device, dtype=torch.float32, non_blocking=True)
|
|
tc = tc.to(device=device, dtype=torch.float32, non_blocking=True)
|
|
tw = tw.to(device=device, dtype=torch.float32, non_blocking=True)
|
|
tt = tt.to(device=device, dtype=torch.float32, non_blocking=True)
|
|
|
|
pred = model(batched, time_b)
|
|
time_remain = (1.0 - time_b).clamp(min=1e-6)
|
|
if args.loss_domain == "displacement":
|
|
pred_center = pred.center * time_remain
|
|
pred_omega = pred.rot_omega * time_remain
|
|
pred_torsion = pred.torsion * time_remain
|
|
tgt_center = tc * time_remain
|
|
tgt_omega = tw * time_remain
|
|
tgt_torsion = tt * time_remain
|
|
else:
|
|
pred_center = pred.center
|
|
pred_omega = pred.rot_omega
|
|
pred_torsion = pred.torsion
|
|
tgt_center = tc
|
|
tgt_omega = tw
|
|
tgt_torsion = tt
|
|
|
|
eps = 1e-8
|
|
scale_c = (tgt_center * tgt_center).mean().clamp(min=eps)
|
|
scale_w = (tgt_omega * tgt_omega).mean().clamp(min=eps)
|
|
scale_t = (tgt_torsion * tgt_torsion).mean().clamp(min=eps)
|
|
if args.rotation_loss == "geodesic":
|
|
rot_loss = so3_geodesic_mse_loss(pred_omega, tgt_omega)
|
|
else:
|
|
rot_loss = F.mse_loss(pred_omega, tgt_omega)
|
|
mse = (
|
|
args.weight_center * (F.mse_loss(pred_center, tgt_center) / scale_c)
|
|
+ args.weight_omega * (rot_loss / scale_w)
|
|
+ args.weight_torsion
|
|
* (
|
|
torsion_wrapped_mse_loss(pred_torsion, tgt_torsion)
|
|
/ scale_t
|
|
)
|
|
)
|
|
(mse / accum).backward()
|
|
sum_mse = sum_mse + mse.detach()
|
|
|
|
if args.grad_clip > 0:
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
|
optimizer.step()
|
|
|
|
avg_mse = sum_mse / accum
|
|
avg_mse_val = float(avg_mse.detach().cpu())
|
|
if avg_mse_val < best_train_mse:
|
|
best_train_mse = avg_mse_val
|
|
core = model._orig_mod if hasattr(model, "_orig_mod") else model
|
|
best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}
|
|
|
|
should_check_early_stop = (
|
|
args.early_stop_patience > 0
|
|
and epoch >= args.early_stop_warmup
|
|
and ((epoch - args.early_stop_warmup) % max(1, args.early_stop_check_every) == 0)
|
|
)
|
|
if should_check_early_stop:
|
|
improved = avg_mse_val < (early_stop_best - args.early_stop_min_delta)
|
|
if improved:
|
|
early_stop_best = avg_mse_val
|
|
early_stop_bad_checks = 0
|
|
else:
|
|
early_stop_bad_checks += 1
|
|
if early_stop_bad_checks >= args.early_stop_patience:
|
|
stop_reason = (
|
|
"early_stop_train_loss_plateau"
|
|
f"(patience={args.early_stop_patience},"
|
|
f"min_delta={args.early_stop_min_delta},"
|
|
f"check_every={max(1, args.early_stop_check_every)})"
|
|
)
|
|
stop_epoch = epoch
|
|
print(
|
|
f"early stopping at epoch {epoch}: "
|
|
f"bad_checks={early_stop_bad_checks}"
|
|
)
|
|
break
|
|
|
|
if epoch % log_every == 0 or epoch == args.epochs - 1:
|
|
print(f"epoch {epoch:6d} train_mse {avg_mse_val:.6f} best_train_mse {best_train_mse:.6f}")
|
|
|
|
stop_epoch = epoch
|
|
|
|
if device.type == "cuda":
|
|
torch.cuda.synchronize()
|
|
print(
|
|
f"training done in {time.perf_counter() - t0:.1f}s "
|
|
f"(stop_reason={stop_reason}, stop_epoch={stop_epoch})"
|
|
)
|
|
|
|
ckpt_path = ""
|
|
if best_state is not None:
|
|
core = model._orig_mod if hasattr(model, "_orig_mod") else model
|
|
core.load_state_dict(best_state)
|
|
print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
|
|
artifacts_dir = os.path.join(here, "artifacts")
|
|
os.makedirs(artifacts_dir, exist_ok=True)
|
|
ckpt_path = os.path.join(artifacts_dir, "latest_eval_best_model.pt")
|
|
ckpt = {
|
|
"state_dict": best_state,
|
|
"model_type": args.model_type,
|
|
"hidden": int(args.hidden),
|
|
"num_layers": int(args.gcn_layers),
|
|
"num_nodes": int(sample.num_nodes),
|
|
"num_torsions": int(sample.bond_endpoints.shape[0]),
|
|
"sample_sdf": args.sdf,
|
|
"sample_pickle_path": args.pickle_path,
|
|
"sample_pickle_pos": int(args.pickle_pos),
|
|
"eval_runs": int(args.eval_runs),
|
|
"rollout_steps": int(args.rollout_steps),
|
|
}
|
|
torch.save(ckpt, ckpt_path)
|
|
print(f"saved best checkpoint: {ckpt_path}")
|
|
|
|
# Final test metric: 100 rollouts mean final-structure RMSD
|
|
final_mean_rmsd_100 = evaluate_mean_rmsd_100(
|
|
model,
|
|
sample,
|
|
sample.gt_coords,
|
|
device=device,
|
|
num_runs=args.eval_runs,
|
|
rollout_steps=args.rollout_steps,
|
|
)
|
|
print(f"final_test_mean_rmsd_{args.eval_runs}: {final_mean_rmsd_100:.6f}")
|
|
|
|
reports_dir = os.path.join(here, "reports")
|
|
os.makedirs(reports_dir, exist_ok=True)
|
|
latest_path = os.path.join(reports_dir, "latest_eval.json")
|
|
timestamp = datetime.now(timezone.utc).isoformat()
|
|
cmd = " ".join(sys.argv)
|
|
|
|
latest_report = {
|
|
"mean_rmsd_100": float(final_mean_rmsd_100),
|
|
"num_runs": int(args.eval_runs),
|
|
"timestamp_utc": timestamp,
|
|
"command": cmd,
|
|
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
|
"best_train_mse": float(best_train_mse),
|
|
"model_source": "best_train_checkpoint",
|
|
"checkpoint_path": ckpt_path,
|
|
"stop_reason": stop_reason,
|
|
"stop_epoch": int(stop_epoch),
|
|
}
|
|
with open(latest_path, "w", encoding="utf-8") as f:
|
|
json.dump(latest_report, f, indent=2)
|
|
|
|
# --- Visualizations ---
|
|
sample.coords = sample.gt_coords.clone()
|
|
save_visualization(sample, os.path.join(args.out_dir, "01_gt_reset.sdf"))
|
|
|
|
random_trans = generate_random_trans(
|
|
torch.tensor([0, 0, 0], dtype=torch.float32), torch.tensor([3, 3, 3], dtype=torch.float32)
|
|
)
|
|
random_rot = generate_rotation_matrix(generate_random_quaternion())
|
|
random_torsion = generate_random_torsion_angles(sample.bond_endpoints.shape[0])
|
|
random_state = RigidTorsionState(center=random_trans, rot=random_rot, torsion=random_torsion)
|
|
|
|
vis = deepcopy(sample)
|
|
vis.coords = vis.gt_coords.clone()
|
|
apply_trans(vis, random_trans)
|
|
apply_rot_matrix(vis, random_rot)
|
|
apply_torsion(vis, random_torsion, vis.movable_mask, vis.bond_endpoints)
|
|
save_visualization(vis, os.path.join(args.out_dir, "02_random_init.sdf"))
|
|
|
|
velocity_gt = rigid_torsion_velocity(random_state, target_state, torch.tensor(0.0))
|
|
state_list_gt = [get_state(random_state, velocity_gt, float(t)) for t in torch.linspace(0, 1, 50)]
|
|
save_video_visualization(sample, state_list_gt, os.path.join(args.out_dir, "03_trajectory_gt_velocity.sdf"))
|
|
|
|
# Learned trajectory: Euler steps using model velocity (batch size 1)
|
|
model.eval()
|
|
state = RigidTorsionState(
|
|
center=random_state.center.to(device),
|
|
rot=random_state.rot.to(device),
|
|
torsion=random_state.torsion.to(device),
|
|
)
|
|
n_frames = 100
|
|
learned_states: list[RigidTorsionState] = []
|
|
for k in range(n_frames):
|
|
learned_states.append(
|
|
RigidTorsionState(
|
|
center=state.center.detach().cpu(),
|
|
rot=state.rot.detach().cpu(),
|
|
torsion=state.torsion.detach().cpu(),
|
|
)
|
|
)
|
|
t_val = k / max(1, n_frames - 1)
|
|
g = deepcopy(sample)
|
|
g.coords = sample.gt_coords.clone()
|
|
apply_trans(g, state.center.cpu())
|
|
apply_rot_matrix(g, state.rot.cpu())
|
|
apply_torsion(g, state.torsion.cpu(), g.movable_mask, g.bond_endpoints)
|
|
g = g.to(device)
|
|
batch = Batch.from_data_list([g]).to(device)
|
|
tb = torch.tensor([[t_val]], device=device, dtype=torch.float32)
|
|
with torch.no_grad():
|
|
v = model(batch, tb)
|
|
vc = v.center.squeeze(0)
|
|
vw = v.rot_omega.squeeze(0)
|
|
vt = v.torsion.squeeze(0)
|
|
dt = 1.0 / max(1, n_frames - 1)
|
|
state = RigidTorsionState(
|
|
center=state.center + vc * dt,
|
|
rot=state.rot @ so3_exp(vw * dt),
|
|
torsion=state.torsion + vt * dt,
|
|
)
|
|
|
|
save_video_visualization(sample, learned_states, os.path.join(args.out_dir, "04_trajectory_learned_model.sdf"))
|
|
|
|
print("Wrote SDFs under:", args.out_dir)
|
|
for name in sorted(os.listdir(args.out_dir)):
|
|
if name.endswith(".sdf"):
|
|
print(" ", name)
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: sys.exit raises SystemExit with main()'s return
    # value, so that value becomes the process exit status.
    sys.exit(main())
|