Restores gcn_residual/graph_readout parity with training in update_best_artifacts; migration reads BEST_PRACTICE command when ckpt keys are missing. Made-with: Cursor
1578 lines
59 KiB
Python
1578 lines
59 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Consolidated RFM toy training + visualization script, ported from sandbox_rfm.ipynb.

- DataLoader workers (num_workers + prefetch) build CPU batches in parallel; the tensors are then moved to the GPU for training.
- pin_memory defaults to False (fewer shared-memory / pin-thread issues); enable it with --pin-memory if your setup is stable.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import math
|
||
import os
|
||
import pickle
|
||
import sys
|
||
import time
|
||
from collections import deque
|
||
from copy import deepcopy
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch.utils.data import DataLoader, Dataset
|
||
from rdkit import Chem
|
||
from rdkit.Chem import SDWriter
|
||
from rdkit.Geometry.rdGeometry import Point3D
|
||
from torch import nn
|
||
from torch_geometric.data import Batch, Data
|
||
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool
|
||
from torch_geometric.utils import add_self_loops, softmax
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Geometry / torsion (from notebook)
|
||
# ---------------------------------------------------------------------------
|
||
def generate_random_trans(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Sample a Gaussian translation with per-component ``mean`` and ``std``.

    Standard-normal noise is drawn with the shape, device and dtype of
    ``mean``; it is then scaled by ``std`` and shifted by ``mean`` (both
    broadcast against the noise).
    """
    noise = torch.randn_like(mean)
    return mean + noise * std
|
||
|
||
|
||
def apply_trans(x: Data, t_vec: torch.Tensor) -> Data:
    """Translate every node coordinate of ``x`` by ``t_vec``.

    ``x.coords`` is rebound to a freshly computed tensor (the previous tensor
    is not written to). The same ``x`` is returned so calls can be chained.
    """
    shifted = torch.add(x.coords, t_vec)
    x.coords = shifted
    return x
|
||
|
||
|
||
def generate_random_quaternion() -> torch.Tensor:
    """Draw a uniformly distributed unit quaternion (Shoemake's method).

    Three uniforms ``u1, u2, u3`` in [0, 1) are mapped to
    ``(sqrt(1-u1) sin 2πu2, sqrt(1-u1) cos 2πu2, sqrt(u1) sin 2πu3, sqrt(u1) cos 2πu3)``.
    Returns a shape-``[4]`` tensor.
    """
    u = torch.rand(3)
    r_lo = (1 - u[0]).sqrt()
    r_hi = u[0].sqrt()
    phi = 2 * torch.pi * u[1]
    psi = 2 * torch.pi * u[2]
    components = [
        r_lo * torch.sin(phi),
        r_lo * torch.cos(phi),
        r_hi * torch.sin(psi),
        r_hi * torch.cos(psi),
    ]
    return torch.stack(components, dim=-1)
|
||
|
||
|
||
def generate_rotation_matrix(q: torch.Tensor) -> torch.Tensor:
    """Convert a scalar-first unit quaternion ``q = (w, x, y, z)`` to a 3x3 rotation matrix.

    The entries are assembled with ``torch.stack``. The result therefore stays
    on ``q``'s device and dtype and remains differentiable w.r.t. ``q``. The
    previous ``torch.tensor(...)`` construction copied each 0-d entry through
    Python and silently detached the result from autograd.

    The formula assumes ``|q| = 1``; a non-unit quaternion gives a
    non-orthogonal matrix.
    """
    w, x, y, z = q[0], q[1], q[2], q[3]
    row0 = torch.stack([1 - 2 * y**2 - 2 * z**2, 2 * (x * y - w * z), 2 * (x * z + w * y)])
    row1 = torch.stack([2 * (x * y + w * z), 1 - 2 * x**2 - 2 * z**2, 2 * (y * z - w * x)])
    row2 = torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * x**2 - 2 * y**2])
    return torch.stack([row0, row1, row2])
|
||
|
||
|
||
def apply_rot_matrix(x: Data, rot_matrix: torch.Tensor) -> Data:
    """Rotate ``x.coords`` by ``rot_matrix`` about the coordinates' centroid.

    Rows are treated as points, so each point ``p`` becomes
    ``R (p - c) + c`` with ``c`` the centroid. ``x.coords`` is rebound and
    ``x`` is returned.
    """
    centroid = x.coords.mean(dim=0)
    rotated = torch.matmul(x.coords - centroid, rot_matrix.T)
    x.coords = rotated + centroid
    return x
|
||
|
||
|
||
def _rodrigues(
    points: torch.Tensor,
    pivot: torch.Tensor,
    axis_unit: torch.Tensor,
    angle: torch.Tensor,
) -> torch.Tensor:
    """Rotate ``points`` by ``angle`` about the line through ``pivot`` along ``axis_unit``.

    This is Rodrigues' rotation formula with ``v = p - pivot`` and ``k = axis_unit``:
    ``v' = v cos(a) + (k x v) sin(a) + k (k . v)(1 - cos(a))``, then ``pivot + v'``.
    ``axis_unit`` is assumed to have unit length; it is not normalized here.
    """
    rel = points - pivot
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    k_cross_v = torch.cross(axis_unit.expand_as(rel), rel, dim=-1)
    k_dot_v = torch.sum(rel * axis_unit, dim=-1, keepdim=True)
    return pivot + rel * cos_a + k_cross_v * sin_a + axis_unit * (k_dot_v * (1 - cos_a))
|
||
|
||
|
||
def apply_torsion(
    x: Any,
    angles: torch.Tensor,
    movable_mask: torch.Tensor,
    bond_endpoints: torch.Tensor,
) -> Any:
    """Rotate torsions of ``x.coords`` in place, one rotatable bond at a time.

    For bond ``b`` with endpoints ``(i, j) = bond_endpoints[b]``, the atoms
    selected by ``movable_mask[b]`` are rotated by ``angles[b]`` about the axis
    ``coords[i] -> coords[j]`` through ``coords[i]``. Bonds are processed in
    order, so each rotation acts on the coordinates left by the previous ones.

    ``x.coords`` is modified in place via ``index_put_``, and ``x`` is returned.
    """
    coords = x.coords
    n_bonds = angles.shape[0]
    ends = bond_endpoints.to(device=coords.device, dtype=torch.long)
    masks = movable_mask.to(device=coords.device)
    theta = angles.to(device=coords.device, dtype=coords.dtype).reshape(n_bonds)
    for b in range(n_bonds):
        i = int(ends[b, 0])
        j = int(ends[b, 1])
        axis = coords[j] - coords[i]
        axis = axis / axis.norm()
        sel = masks[b].nonzero(as_tuple=True)[0]
        coords.index_put_((sel,), _rodrigues(coords[sel], coords[i], axis, theta[b]))
    return x
|
||
|
||
|
||
def generate_random_torsion_angles(
    num_bonds: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Return ``num_bonds`` torsion angles (radians) drawn uniformly from [0, 2π)."""
    return torch.empty(num_bonds, dtype=dtype, device=device).uniform_(0, 2 * torch.pi)
|
||
|
||
|
||
def _movable_j_side(edge_index: torch.Tensor, i: int, j: int, num_atoms: int) -> torch.Tensor:
    """Return a boolean ``[num_atoms]`` mask of the atoms reachable from ``j`` once bond ``{i, j}`` is cut.

    Every edge whose endpoint set equals ``{i, j}`` (in either direction) is
    ignored. The remaining edges (rows 0 and 1 of ``edge_index``) are treated
    as undirected and explored breadth-first from ``j``. ``j`` itself is
    always marked.
    """
    cut = {i, j}
    neighbours: list[list[int]] = [[] for _ in range(num_atoms)]
    for a, b in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        a, b = int(a), int(b)
        if {a, b} == cut:
            continue
        neighbours[a].append(b)
        neighbours[b].append(a)

    reached = {j}
    frontier = deque([j])
    while frontier:
        node = frontier.popleft()
        for nb in neighbours[node]:
            if nb not in reached:
                reached.add(nb)
                frontier.append(nb)

    mask = torch.zeros(num_atoms, dtype=torch.bool)
    mask[list(reached)] = True
    return mask
|
||
|
||
|
||
def bfs_torsion_masks(x: Any, root: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the rotatable bonds of ``x.rdmol`` and the atoms each torsion moves.

    A bond counts as rotatable if it matches the SMARTS
    ``[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]``: a single, non-ring bond whose two atoms
    are neither terminal (degree 1) nor part of a triple bond.

    Bonds are oriented ``(i, j)`` with ``i`` on the side closer to ``root``.
    The rotatable edges of a BFS tree from ``root`` come first, in discovery
    order (neighbours are visited in ascending index order). Any remaining
    rotatable bonds (e.g. in fragments not reached from ``root``) follow, each
    oriented by (depth, index) and then sorted. ``movable_mask[k]`` marks the
    atoms on ``j``'s side of bond ``k``.

    Args:
        x: graph exposing ``rdmol``, ``coords`` (``[N, 3]``) and ``edge_index``.
        root: atom index the BFS starts from.

    Returns:
        ``(bond_endpoints, movable_mask)`` as ``long [T, 2]`` and ``bool [T, N]``.
        When there is no rotatable bond, these are empty ``[0, 2]`` / ``[0, N]``
        tensors. Previously ``torch.stack`` raised on the empty list and
        ``bond_endpoints`` came back 1-D.
    """
    rot_pat = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")
    rot_set = {frozenset((int(a), int(b))) for a, b in x.rdmol.GetSubstructMatches(rot_pat)}
    N = int(x.coords.shape[0])
    ei = x.edge_index

    # Undirected adjacency, sorted so the BFS (and hence the bond order) is deterministic.
    adj: list[list[int]] = [[] for _ in range(N)]
    for e in range(ei.shape[1]):
        u, v = int(ei[0, e]), int(ei[1, e])
        adj[u].append(v)
        adj[v].append(u)
    for u in range(N):
        adj[u].sort()

    depth = [-1] * N
    depth[root] = 0
    q = deque([root])
    ordered_ij: list[tuple[int, int]] = []
    tree_pairs: set[frozenset[int]] = set()
    while q:
        u = q.popleft()
        for v in adj[u]:
            if depth[v] != -1:
                continue
            depth[v] = depth[u] + 1
            q.append(v)
            fs = frozenset((u, v))
            tree_pairs.add(fs)
            if fs in rot_set:
                # Tree edge: u is the parent, so j = v points away from root.
                ordered_ij.append((u, v))

    def orient(fs: frozenset[int]) -> tuple[int, int]:
        # Shallower atom first; ties (including unreached atoms at depth -1) by index.
        a, b = tuple(fs)
        da, db = depth[a], depth[b]
        if da < db or (da == db and a < b):
            return (a, b)
        return (b, a)

    for fs in sorted(rot_set - tree_pairs, key=orient):
        ordered_ij.append(orient(fs))

    if not ordered_ij:
        return torch.zeros((0, 2), dtype=torch.long), torch.zeros((0, N), dtype=torch.bool)

    bond_endpoints = torch.tensor(ordered_ij, dtype=torch.long)
    movable_mask = torch.stack([_movable_j_side(ei, i, j, N) for i, j in ordered_ij])
    return bond_endpoints, movable_mask
|
||
|
||
|
||
def vel_center(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return the constant translation velocity that carries ``current`` to ``target`` by time 1.

    The remaining horizon ``1 - t`` is floored at ``1e-6`` to avoid division by zero.
    """
    horizon = (1 - t).clamp(min=1e-6)
    displacement = target - current
    return displacement / horizon
|
||
|
||
|
||
def wrap_angle(d: torch.Tensor) -> torch.Tensor:
    """Map angles (radians) into [-π, π] via ``atan2(sin d, cos d)``."""
    s = torch.sin(d)
    c = torch.cos(d)
    return torch.atan2(s, c)
|
||
|
||
|
||
def vel_torsion(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return the torsion velocity that reaches ``target`` from ``current`` along the shortest arc by time 1.

    The angular difference is wrapped into [-π, π] and divided by the
    remaining horizon ``1 - t``, which is floored at ``1e-6``.
    """
    diff = target - current
    shortest = torch.atan2(torch.sin(diff), torch.cos(diff))
    return shortest / (1 - t).clamp(min=1e-6)
|
||
|
||
|
||
def torsion_wrapped_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error of the periodic residual ``pred - target``, wrapped into [-π, π]."""
    resid = pred - target
    resid = torch.atan2(torch.sin(resid), torch.cos(resid))
    return (resid ** 2).mean()
|
||
|
||
|
||
def upper_quantile_mean(values: torch.Tensor, quantile: float) -> torch.Tensor:
    """Mean of the upper tail of ``values`` above the ``quantile`` level.

    ``quantile`` is clamped to [0, 1]. A value of ``0`` returns the plain mean
    and ``1`` returns the maximum. Otherwise the ``ceil(n * (1 - quantile))``
    largest entries (at least one) are averaged, where ``n = values.numel()``.

    The tail size is rounded to 9 decimals before ``ceil``, so floating-point
    noise in ``1 - quantile`` no longer adds an extra element. For example,
    ``n=10, quantile=0.7`` gives ``10 * 0.30000000000000004``, which
    previously selected 4 values instead of 3. An empty ``values`` returns
    ``values.sum()`` (a zero scalar) instead of failing inside ``topk``.
    """
    n = int(values.numel())
    if n == 0:
        return values.sum()
    q = float(max(0.0, min(1.0, quantile)))
    if q <= 0.0:
        return values.mean()
    if q >= 1.0:
        return values.max()
    tail = round(n * (1.0 - q), 9)
    k = min(n, max(1, int(math.ceil(tail))))
    top_vals, _ = torch.topk(values.reshape(-1), k=k, largest=True)
    return top_vals.mean()
|
||
|
||
|
||
def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor) -> torch.Tensor:
    """Mean squared SO(3) geodesic distance between batched axis-angle vectors.

    Each row of ``pred_omega`` / ``target_omega`` (shape ``[B, 3]``) is mapped
    to a rotation with ``so3_exp``. The per-sample loss is ``|log(R_pred^T R_tgt)|^2``.

    To keep the loss finite, non-finite inputs are zeroed and clamped to
    [-20, 20] before the exponential map. The log residual is sanitized the
    same way and clamped to [-10, 10].

    Raises:
        ValueError: if ``pred_omega`` is not ``[B, 3]``.
    """
    if pred_omega.ndim != 2 or pred_omega.shape[-1] != 3:
        raise ValueError("pred_omega must be [B,3]")

    def _sanitize(t: torch.Tensor, bound: float) -> torch.Tensor:
        return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0).clamp(-bound, bound)

    pred_s = _sanitize(pred_omega, 20.0)
    tgt_s = _sanitize(target_omega, 20.0)
    per_sample = []
    for row in range(pred_s.shape[0]):
        residual = so3_log(so3_exp(pred_s[row]).T @ so3_exp(tgt_s[row]))
        residual = _sanitize(residual, 10.0)
        per_sample.append((residual * residual).sum())
    return torch.stack(per_sample).mean()
|
||
|
||
|
||
def clip_vector_norm_batch(x: torch.Tensor, max_norm: float) -> torch.Tensor:
    """Rescale each vector along the last dim so its L2 norm is at most ``max_norm``.

    Vectors already within the limit keep scale 1. ``max_norm <= 0`` disables
    clipping and returns ``x`` itself.
    """
    if max_norm <= 0:
        return x
    lengths = torch.linalg.norm(x, dim=-1, keepdim=True).clamp(min=1e-9)
    factor = (max_norm / lengths).clamp(max=1.0)
    return x * factor
|
||
|
||
|
||
def skew(w: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 skew-symmetric (cross-product) matrix of the 3-vector ``w``.

    ``skew(w) @ v`` equals ``cross(w, v)``. The matrix is built with
    ``torch.stack``, so it stays differentiable w.r.t. ``w``.
    """
    zero = w.new_zeros(())
    wx, wy, wz = w[0], w[1], w[2]
    rows = (
        torch.stack([zero, -wz, wy]),
        torch.stack([wz, zero, -wx]),
        torch.stack([-wy, wx, zero]),
    )
    return torch.stack(rows)
|
||
|
||
|
||
def so3_log(R: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Logarithm map SO(3) -> so(3): return the axis-angle vector ``theta * axis`` of ``R``.

    ``theta = acos((tr(R) - 1) / 2)``, and the axis comes from the
    antisymmetric part of ``R``. The cosine is clamped strictly inside
    (-1, 1) so that ``acos`` and its gradient stay finite.

    The clamp margin is ``max(eps, finfo(R.dtype).eps)``. In float32 the
    default ``1 - 1e-8`` rounds to exactly ``1.0``, so the old clamp was a
    no-op and (near-)identity rotations produced NaN gradients.

    The result is ill-conditioned near ``theta = pi``, where the antisymmetric
    part vanishes.

    Args:
        R: ``[3, 3]`` rotation matrix.
        eps: minimum distance of the clamped cosine from +/-1; also added to the
            ``2 sin(theta)`` denominator.
    """
    margin = eps
    if R.is_floating_point():
        margin = max(eps, float(torch.finfo(R.dtype).eps))
    cos_theta = ((R.trace() - 1) * 0.5).clamp(-1 + margin, 1 - margin)
    theta = torch.acos(cos_theta)
    w = torch.stack(
        [R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]
    ) / (2 * torch.sin(theta) + eps)
    return w * theta
|
||
|
||
|
||
def so3_exp(omega: torch.Tensor) -> torch.Tensor:
    """Exponential map so(3) -> SO(3) of an axis-angle 3-vector (Rodrigues' formula).

    Returns ``I + sin(theta) K + (1 - cos(theta)) K^2``, where ``theta = |omega|``
    and ``K`` is the skew matrix of the unit axis. For ``omega = 0`` the
    ``1e-12`` guard keeps the axis finite, and the result is the identity.
    """
    angle = omega.norm()
    axis = omega / (angle + 1e-12)
    K = skew(axis)
    K_sq = K @ K
    eye = torch.eye(3, device=omega.device, dtype=omega.dtype)
    return eye + torch.sin(angle) * K + (1 - torch.cos(angle)) * K_sq
|
||
|
||
|
||
def vel_rot_mat(current: torch.Tensor, target: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return the angular velocity (axis-angle 3-vector) rotating ``current`` onto ``target`` by time 1.

    Computes ``so3_log(current^T @ target) / max(1 - t, 1e-6)``. At ``t = 0``,
    ``current @ so3_exp(omega)`` reproduces ``target``.
    """
    relative = current.T @ target
    horizon = (1 - t).clamp(min=1e-6)
    return so3_log(relative) / horizon
|
||
|
||
|
||
@dataclass
class RigidTorsionState:
    """Pose of a flexible molecule: rigid translation, rotation, and torsion angles.

    ``graph_from_state_cpu`` poses a structure from this state. It translates
    the reference coordinates by ``center``, rotates them about their centroid
    by ``rot``, and then applies ``torsion``.
    """

    # Translation added to every atom (built as a 3-vector in this file).
    center: torch.Tensor
    # 3x3 rotation matrix, applied about the coordinate centroid.
    rot: torch.Tensor
    # One torsion angle (radians) per rotatable bond.
    torsion: torch.Tensor
|
||
|
||
|
||
@dataclass
class RigidTorsionVelocity:
    """Velocity in (translation, rotation, torsion) space; integrated by ``get_state``."""

    # Translation velocity (same layout as RigidTorsionState.center).
    center: torch.Tensor
    # Angular velocity as an axis-angle vector; get_state right-multiplies the rotation by so3_exp(rot_omega * t).
    rot_omega: torch.Tensor
    # Per-bond torsion angular velocity (radians per unit time).
    torsion: torch.Tensor
|
||
|
||
|
||
def rigid_torsion_velocity(
    current: RigidTorsionState, target: RigidTorsionState, t: torch.Tensor
) -> RigidTorsionVelocity:
    """Return the constant velocities that move ``current`` onto ``target`` by time 1.

    Translation uses ``vel_center``, torsions the shortest-arc ``vel_torsion``,
    and rotation the SO(3) log in ``vel_rot_mat``. ``t`` is converted with
    ``torch.as_tensor`` first.
    """
    t = torch.as_tensor(t)
    return RigidTorsionVelocity(
        center=vel_center(current.center, target.center, t),
        torsion=vel_torsion(current.torsion, target.torsion, t),
        rot_omega=vel_rot_mat(current.rot, target.rot, t),
    )
|
||
|
||
|
||
def get_state(
    current: RigidTorsionState, velocity: RigidTorsionVelocity, t: float
) -> RigidTorsionState:
    """Advance ``current`` for time ``t`` at constant ``velocity``.

    Translation and torsions move linearly. The rotation is right-multiplied
    by ``so3_exp(rot_omega * t)``. Torsions are not re-wrapped into [-π, π].
    """
    dt = torch.as_tensor(t)
    return RigidTorsionState(
        center=current.center + velocity.center * dt,
        rot=current.rot @ so3_exp(velocity.rot_omega * dt),
        torsion=current.torsion + velocity.torsion * dt,
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Visualization
|
||
# ---------------------------------------------------------------------------
|
||
def save_visualization(sample: Data, path: str) -> None:
    """Write ``sample.rdmol`` as a MOL file to ``path``, with atoms placed at ``sample.coords``.

    A deep copy of the molecule is edited, so ``sample.rdmol``'s own conformer
    is left untouched.
    """
    mol = deepcopy(sample.rdmol)
    conf = mol.GetConformer()
    for atom_idx, (x, y, z) in enumerate(sample.coords.tolist()):
        conf.SetAtomPosition(atom_idx, Point3D(x, y, z))
    Chem.MolToMolFile(mol, path)
|
||
|
||
|
||
def save_video_visualization(
    sample: Data,
    states: list[RigidTorsionState],
    path: str,
) -> None:
    """Write one SDF frame per entry of ``states`` to ``path``.

    Each frame starts from ``sample.gt_coords`` and applies the state's
    translation, then a rotation about the centroid, then the torsions.

    ``sample`` is no longer modified; the previous version overwrote
    ``sample.coords`` with ``gt_coords`` as a side effect. The SDF writer is
    closed even if posing a frame raises.
    """
    mol = deepcopy(sample.rdmol)
    conformer = mol.GetConformer()
    # Private copy so the caller's sample keeps its current coords.
    base = deepcopy(sample)
    base.coords = sample.gt_coords.clone()
    with open(path, "w") as f:
        writer = SDWriter(f)
        try:
            for state in states:
                frame = deepcopy(base)
                apply_trans(frame, state.center)
                apply_rot_matrix(frame, state.rot)
                apply_torsion(frame, state.torsion, frame.movable_mask, frame.bond_endpoints)
                for i, (x, y, z) in enumerate(frame.coords.tolist()):
                    conformer.SetAtomPosition(i, Point3D(x, y, z))
                writer.write(mol)
        finally:
            writer.close()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Evaluation
|
||
# ---------------------------------------------------------------------------
|
||
def kabsch_rmsd(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Return the RMSD between ``pred`` and ``target`` ([N, 3]) after optimal rigid superposition (Kabsch).

    Both point sets are centered. With ``H = P^T Q = U S V^T``, the optimal
    rotation acting on column vectors is ``R = V diag(1, 1, d) U^T``, so the
    aligned rows are ``P @ R^T``. The previous version computed ``P @ R``,
    i.e. it applied the inverse rotation, which overestimated the RMSD
    whenever the optimal rotation was not symmetric.

    Returns the sentinel ``1e6`` if either input is non-finite or the SVD fails.
    """
    if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
        return 1e6
    p = pred - pred.mean(dim=0, keepdim=True)
    q = target - target.mean(dim=0, keepdim=True)
    h = p.T @ q
    try:
        u, _, v_t = torch.linalg.svd(h)
    except RuntimeError:
        return 1e6
    r = v_t.T @ u.T
    if torch.det(r) < 0:
        # Improper rotation (reflection): flip the last right-singular vector.
        v_t[-1, :] *= -1
        r = v_t.T @ u.T
    # Rows are points, so the rotation R acts as p @ R^T.
    p_aligned = p @ r.T
    return float(torch.sqrt(torch.mean(torch.sum((p_aligned - q) ** 2, dim=-1))).item())
|
||
|
||
|
||
def graph_from_state_cpu(sample: Data, center: torch.Tensor, rot: torch.Tensor, torsion: torch.Tensor) -> Data:
    """Return a deep copy of ``sample`` posed by (``center``, ``rot``, ``torsion``).

    The copy starts from ``sample.gt_coords``. It is translated, rotated about
    its centroid, and then the torsions are applied. ``sample`` is not modified.
    """
    posed = deepcopy(sample)
    posed.coords = sample.gt_coords.clone()
    posed = apply_trans(posed, center)
    posed = apply_rot_matrix(posed, rot)
    return apply_torsion(posed, torsion, posed.movable_mask, posed.bond_endpoints)
|
||
|
||
|
||
def evaluate_mean_rmsd_100(
    model: nn.Module,
    sample: Data,
    target_coords: torch.Tensor,
    device: torch.device,
    num_runs: int = 100,
    rollout_steps: int = 100,
    omega_max_norm: float = 0.0,
) -> float:
    """Return the mean Kabsch RMSD of one-shot (t=0 -> t=1) predictions from ``num_runs`` random starts.

    Each run builds a random pose from ``sample.gt_coords``: a Gaussian
    translation (std 3), a uniform random rotation and uniform torsions. All
    poses are batched, and the model is queried once at ``t = 0``. The
    predicted velocity is then applied for unit time in a single step:
    translation and torsions are added, and the rotation is right-multiplied
    by ``so3_exp(omega)``, with ``omega`` optionally norm-clipped to
    ``omega_max_norm``. The returned value is the RMSD of each final structure
    to ``target_coords``, averaged over runs.

    ``rollout_steps`` is accepted for API compatibility and ignored. The
    model's train/eval mode is restored on exit; previously it was left in
    eval mode.

    Raises:
        ValueError: if ``num_runs < 1``.
    """
    del rollout_steps
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    was_training = model.training
    model.eval()
    try:
        nb = sample.bond_endpoints.shape[0]
        centers = []
        rots = []
        torsions = []
        for _ in range(num_runs):
            centers.append(generate_random_trans(torch.zeros(3), torch.tensor([3.0, 3.0, 3.0])))
            rots.append(generate_rotation_matrix(generate_random_quaternion()))
            torsions.append(generate_random_torsion_angles(nb))
        center_t = torch.stack(centers, dim=0)  # [R,3]
        rot_t = torch.stack(rots, dim=0)  # [R,3,3]
        torsion_t = torch.stack(torsions, dim=0)  # [R,B]

        graphs = [
            graph_from_state_cpu(sample, center_t[i], rot_t[i], torsion_t[i])
            for i in range(num_runs)
        ]
        batch = Batch.from_data_list(graphs).to(device)
        tb = torch.zeros((num_runs, 1), device=device, dtype=torch.float32)
        with torch.no_grad():
            v = model(batch, tb)
        vc = v.center.detach().cpu()
        vw = clip_vector_norm_batch(v.rot_omega.detach().cpu(), omega_max_norm)
        vt = v.torsion.detach().cpu()

        # One Euler step of unit length from t=0 to t=1.
        center_t = center_t + vc
        torsion_t = torsion_t + vt
        for i in range(num_runs):
            rot_t[i] = rot_t[i] @ so3_exp(vw[i])

        rmsds = []
        for i in range(num_runs):
            final_graph = graph_from_state_cpu(sample, center_t[i], rot_t[i], torsion_t[i])
            rmsds.append(kabsch_rmsd(final_graph.coords, target_coords))
        return float(sum(rmsds) / len(rmsds))
    finally:
        model.train(was_training)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Model
|
||
# ---------------------------------------------------------------------------
|
||
class TorsionSubgraphSpec(nn.Module):
    """Movable-side induced subgraph for one torsion, stored as persistent buffers.

    ``nodes_full`` holds the subgraph's node ids in the full graph.
    ``edge_index_local`` holds its edges re-indexed into ``0..len(nodes_full)-1``.
    Because both are buffers, they appear in ``state_dict`` and follow
    ``.to(device)``.
    """

    def __init__(self, nodes_full: torch.Tensor, edge_index_local: torch.Tensor):
        super().__init__()
        buffers = (("nodes_full", nodes_full), ("edge_index_local", edge_index_local))
        for buf_name, tensor in buffers:
            self.register_buffer(buf_name, tensor, persistent=True)
|
||
|
||
|
||
def build_torsion_subgraph_specs(sample: Data) -> nn.ModuleList:
    """Build one ``TorsionSubgraphSpec`` per rotatable bond of ``sample``.

    Subgraph ``k`` is induced on ``movable_mask[k]`` plus both hinge atoms
    ``(i, j) = bond_endpoints[k]``. The endpoints carry the hinge geometry even
    when one of them lies on the ``fixed'' side of the mask. Node ids are
    sorted and re-indexed to ``0..n_k-1`` in the local edge list. If no edge of
    ``sample.edge_index`` falls inside the node set, one self-loop per node is
    used instead.

    Raises:
        RuntimeError: if a subgraph has no nodes. This is unreachable while the
            hinge atoms are always added; the check is kept as a safeguard.
    """

    def _on_cpu(t: torch.Tensor) -> torch.Tensor:
        return t if t.device.type == "cpu" else t.cpu()

    ei = _on_cpu(sample.edge_index)
    mv = _on_cpu(sample.movable_mask)
    be = _on_cpu(sample.bond_endpoints)
    # Hoist the edge list to Python ints once instead of per torsion.
    src = [int(a) for a in ei[0].tolist()]
    dst = [int(b) for b in ei[1].tolist()]

    specs: list[TorsionSubgraphSpec] = []
    for k in range(int(sample.bond_endpoints.shape[0])):
        movable = torch.where(mv[k].bool())[0].long()
        hinge = torch.tensor([int(be[k, 0].item()), int(be[k, 1].item())], dtype=torch.long)
        nodes_full = torch.unique(torch.cat([movable, hinge]), sorted=True)
        if nodes_full.numel() == 0:
            raise RuntimeError(f"torsion {k}: empty movable side; cannot build subgraph")
        local_id = {old: new for new, old in enumerate(sorted({int(v) for v in nodes_full.tolist()}))}
        kept = [(local_id[a], local_id[b]) for a, b in zip(src, dst) if a in local_id and b in local_id]
        if kept:
            edge_index_local = torch.tensor(
                [[a for a, _ in kept], [b for _, b in kept]], dtype=torch.long
            )
        else:
            loop = torch.arange(nodes_full.numel(), dtype=torch.long)
            edge_index_local = torch.stack([loop, loop], dim=0)
        specs.append(TorsionSubgraphSpec(nodes_full, edge_index_local))
    return nn.ModuleList(specs)
|
||
|
||
|
||
class AttentionGraphReadout(nn.Module):
    """Graph readout as an attention-weighted sum of node features.

    A small gate MLP scores each node. The scores are softmax-normalized
    within each graph (``batch`` assigns nodes to graphs) and used to weight
    the node features before they are summed per graph. This gives a
    different inductive bias from mean pooling.
    """

    def __init__(self, node_dim: int, gate_hidden: int = 64):
        super().__init__()
        gate_layers = [
            nn.Linear(node_dim, gate_hidden),
            nn.SiLU(),
            nn.Linear(gate_hidden, 1),
        ]
        self.gate = nn.Sequential(*gate_layers)

    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        logits = self.gate(h).squeeze(-1)
        weights = softmax(logits, batch).unsqueeze(-1)
        return global_add_pool(h * weights, batch)
|
||
|
||
|
||
class RFMModel(nn.Module):
    """Model for velocity prediction. Supports GCN and fixed-order MLP encoders.

    ``forward(x, time)`` returns a ``RigidTorsionVelocity`` with:
      * a translation velocity ``[B, 3]``,
      * an axis-angle angular velocity ``[B, 3]``,
      * one torsion velocity per rotatable bond ``[B, num_torsions]``.

    Trunk:
      * ``"gcn"``: a stack of ``GCNConv`` layers over ``x.coords`` (3 features
        per node), followed by mean pooling or ``AttentionGraphReadout``.
      * ``"mlp"``: an MLP over the flattened coordinates. This reshapes
        ``x.coords`` to ``[B, num_nodes * 3]``, so every graph must have exactly
        ``num_nodes`` atoms in a fixed order.

    The pooled trunk feature is concatenated with the time features
    ``[t, sin 2πt, cos 2πt]`` and fed to the per-quantity heads.

    Torsion head:
      * ``"global"``: an MLP head on the pooled feature.
      * ``"bond_pair"`` (GCN only): a per-torsion subgraph GCN over
        ``torsion_subgraphs``, combined with the pooled feature by a small MLP.

    NOTE(review): submodule attribute names define the ``state_dict`` keys, and
    the order in which submodules are constructed determines parameter
    initialization under a fixed seed. Keep both stable when editing.
    """

    def __init__(
        self,
        model_type: str,
        num_nodes: int,
        num_torsions: int,
        hidden: int = 128,
        num_layers: int = 4,
        gcn_residual: bool = False,
        channel_layernorm: bool = False,
        head_mlp_layers: int = 1,
        graph_readout: str = "mean",
        torsion_head_mode: str = "global",
        torsion_subgraphs: nn.ModuleList | None = None,
        torsion_subgraph_output_clip: float = 0.0,
        torsion_sub_gcn_layers: int | None = None,
    ):
        """Build the trunk, the time-conditioned heads and the torsion head.

        Args:
            model_type: ``"gcn"`` or ``"mlp"`` trunk.
            num_nodes: atoms per graph; sets the MLP trunk's input size (``num_nodes * 3``).
            num_torsions: number of torsion outputs.
            hidden: trunk and subgraph hidden width.
            num_layers: number of GCN or MLP trunk layers.
            gcn_residual: add identity skips between GCN layers of equal width.
            channel_layernorm: LayerNorm on the pooled+time feature and inside the deeper heads.
            head_mlp_layers: Linear layers per output head (values below 1 are treated as 1).
            graph_readout: ``"mean"`` or ``"attention"`` pooling (validated for the GCN trunk only).
            torsion_head_mode: ``"global"`` or ``"bond_pair"``.
            torsion_subgraphs: per-torsion subgraph specs; required for ``"bond_pair"``.
            torsion_subgraph_output_clip: if > 0, clamp bond_pair torsion outputs to ``[-c, c]``.
            torsion_sub_gcn_layers: subgraph GCN depth (default ``num_layers``; clamped to ``[1, num_layers]``).

        Raises:
            ValueError: on an unsupported option or an invalid combination.
        """
        super().__init__()
        self.model_type = model_type
        self.num_nodes = num_nodes
        self.num_torsions = num_torsions
        self.hidden = hidden
        self.num_layers = num_layers
        self.gcn_residual = gcn_residual
        self.channel_layernorm = channel_layernorm
        self.act = nn.SiLU()
        self.head_mlp_layers = max(1, int(head_mlp_layers))
        self.graph_readout = graph_readout
        self.torsion_head_mode = torsion_head_mode
        self.torsion_subgraph_output_clip = float(max(0.0, torsion_subgraph_output_clip))
        self.sub_gcn_skip: nn.ModuleList | None = None

        if model_type == "gcn":
            # First layer lifts the 3 coordinate features to `hidden`.
            self.convs = nn.ModuleList()
            self.convs.append(GCNConv(3, hidden))
            for _ in range(num_layers - 1):
                self.convs.append(GCNConv(hidden, hidden))
            # Parameter-free skip modules, one per conv; only used when gcn_residual is set.
            self.gcn_skip = nn.ModuleList([nn.Identity()])
            for _ in range(num_layers - 1):
                self.gcn_skip.append(nn.Identity())
            trunk_dim = hidden
            if graph_readout == "mean":
                # None selects global_mean_pool in forward().
                self._pool = None
            elif graph_readout == "attention":
                self._pool = AttentionGraphReadout(trunk_dim, gate_hidden=max(32, min(128, hidden)))
            else:
                raise ValueError(f"Unsupported graph_readout: {graph_readout}")
        elif model_type == "mlp":
            self._pool = None
            mlp_layers = []
            in_dim = num_nodes * 3
            for _ in range(num_layers):
                mlp_layers.append(nn.Linear(in_dim, hidden))
                mlp_layers.append(nn.SiLU())
                in_dim = hidden
            self.encoder = nn.Sequential(*mlp_layers)
            trunk_dim = hidden
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

        # Time features: [t, sin(2πt), cos(2πt)].
        self.time_feat_dim = 3
        d = trunk_dim + self.time_feat_dim
        self.shared_norm = nn.LayerNorm(d) if channel_layernorm else nn.Identity()

        def make_head(out_dim: int) -> nn.Sequential:
            # (head_mlp_layers - 1) hidden blocks of width d, then a Linear to out_dim.
            layers: list[nn.Module] = []
            in_dim = d
            for _ in range(self.head_mlp_layers - 1):
                layers.append(nn.Linear(in_dim, d))
                if channel_layernorm:
                    layers.append(nn.LayerNorm(d))
                layers.append(nn.SiLU())
                in_dim = d
            layers.append(nn.Linear(in_dim, out_dim))
            return nn.Sequential(*layers)

        self.trans_head = make_head(3)
        self.rot_head = make_head(3)
        if torsion_head_mode == "global":
            self.torsion_head = make_head(num_torsions)
            # bond_pair-only attributes are set to None so the attribute set is the same in both modes.
            self.torsion_pair_mlp = None
            self.torsion_pair_in_norm = None
            self.torsion_subgraphs = None
            self.sub_convs = None
            self.sub_pool_norm = None
            self.sub_gcn_skip = None
        elif torsion_head_mode == "bond_pair":
            if model_type != "gcn":
                raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
            if torsion_subgraphs is None or len(torsion_subgraphs) != num_torsions:
                raise ValueError("bond_pair requires torsion_subgraphs built from sample.movable_mask")
            # Dedicated subgraph encoder (weights not shared with full-graph convs).
            # First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
            # throwing away trunk geometry context when re-encoding the movable fragment only.
            ns = int(num_layers if torsion_sub_gcn_layers is None else torsion_sub_gcn_layers)
            ns = max(1, min(ns, num_layers))
            self.sub_num_layers = ns
            self.sub_convs = nn.ModuleList()
            self.sub_convs.append(GCNConv(3 + hidden, hidden))
            for _ in range(ns - 1):
                self.sub_convs.append(GCNConv(hidden, hidden))
            self.sub_gcn_skip = nn.ModuleList([nn.Identity() for _ in range(ns)])
            # Global pooled context + mean pool over subgraph GCN.
            pair_in = d + hidden
            ph = max(64, min(256, hidden * 2))
            self.torsion_head = None
            self.torsion_subgraphs = torsion_subgraphs
            self.sub_pool_norm = nn.LayerNorm(hidden)
            self.torsion_pair_in_norm = nn.LayerNorm(pair_in)
            self.torsion_pair_mlp = nn.Sequential(
                nn.Linear(pair_in, ph),
                nn.SiLU(),
                nn.Linear(ph, ph),
                nn.SiLU(),
                nn.Linear(ph, 1),
            )

            def _small_linear_init(m: nn.Module) -> None:
                # Down-scaled Xavier init (gain 0.25) and zero bias for the pair MLP's Linear layers.
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight, gain=0.25)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

            self.torsion_pair_mlp.apply(_small_linear_init)
        else:
            raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")

    def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor, h_full: torch.Tensor) -> torch.Tensor:
        """Batched subgraph GCN; node input is [coords | h_full] gathered onto each torsion subgraph.

        Args:
            x: batched graph; ``x.ptr`` gives each graph's node offset.
                NOTE(review): assumes every graph in the batch has the same
                topology as the sample the specs were built from.
            pooled: ``[B, d]`` pooled trunk + time feature.
            h_full: final trunk node embeddings, ``[total_nodes, hidden]``.

        Returns:
            ``[B, num_torsions]`` torsion velocities (non-finite values zeroed,
            optionally clamped to ``[-clip, clip]``).
        """
        if self.num_torsions == 0:
            return pooled.new_zeros((pooled.shape[0], 0))
        assert self.torsion_subgraphs is not None
        assert self.sub_convs is not None
        assert self.torsion_pair_mlp is not None and self.torsion_pair_in_norm is not None
        if not hasattr(x, "ptr") or x.ptr is None:
            raise RuntimeError("bond_pair torsion head requires Batch.ptr (use Batch.from_data_list)")

        B = int(pooled.shape[0])
        T = self.num_torsions
        ptr = x.ptr.to(device=pooled.device)
        device = pooled.device
        dtype = pooled.dtype
        # One small graph per (torsion k, batch item b), appended k-major / b-minor;
        # the view(T, B, ...) below relies on this order.
        graphs: list[Data] = []
        for k in range(T):
            spec = self.torsion_subgraphs[k]
            nodes_f = spec.nodes_full.to(device=device, dtype=torch.long)
            ei_loc = spec.edge_index_local.to(device=device)
            for b in range(B):
                base = int(ptr[b].item())
                idx = base + nodes_f
                coords_sub = x.coords[idx].to(dtype=dtype)
                h_sub = h_full[idx].to(dtype=dtype)
                x_sub = torch.cat([coords_sub, h_sub], dim=-1)
                graphs.append(Data(x=x_sub, edge_index=ei_loc))
        batch_sub = Batch.from_data_list(graphs).to(device)
        h = batch_sub.x
        edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
        batch_vec = batch_sub.batch
        assert self.sub_gcn_skip is not None
        for idx, conv in enumerate(self.sub_convs):
            h_new = self.act(conv(h, edge_index))
            # The first layer changes width (3 + hidden -> hidden), so the shape check skips its residual.
            if self.gcn_residual and h_new.shape == h.shape:
                h = h_new + self.sub_gcn_skip[idx](h)
            else:
                h = h_new
        sub_pooled = global_mean_pool(h, batch_vec)
        assert self.sub_pool_norm is not None
        sub_pooled = self.sub_pool_norm(sub_pooled)
        # [T*B, hidden] -> [B, T, hidden] (see the ordering note above).
        sub_bt = sub_pooled.view(T, B, self.hidden).transpose(0, 1)
        pooled_exp = pooled.unsqueeze(1).expand(B, T, pooled.shape[-1])
        feat = torch.cat([pooled_exp, sub_bt], dim=-1).reshape(B * T, -1)
        feat = self.torsion_pair_in_norm(feat)
        out = self.torsion_pair_mlp(feat).squeeze(-1).view(B, T)
        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
        c = self.torsion_subgraph_output_clip
        if c > 0:
            out = out.clamp(min=-c, max=c)
        return out

    def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
        """Predict rigid + torsion velocities for the posed graphs in ``x`` at ``time``.

        Args:
            x: batched graph with ``coords`` and ``edge_index`` (plus ``batch`` / ``ptr``
                on the GCN paths).
            time: flow time with one row per graph. It is concatenated along the
                last dim with the ``[B, hidden]`` trunk, and ``time_feat_dim = 3``
                implies a trailing dim of 1, i.e. ``[B, 1]``.
        """
        if self.model_type == "gcn":
            h = x.coords
            edge_index = x.edge_index
            batch = x.batch
            for idx, conv in enumerate(self.convs):
                h_new = self.act(conv(h, edge_index))
                # The first layer changes width (3 -> hidden), so the shape check skips its residual.
                if self.gcn_residual and h_new.shape == h.shape:
                    h = h_new + self.gcn_skip[idx](h)
                else:
                    h = h_new
            if self._pool is None:
                trunk = global_mean_pool(h, batch)
            else:
                trunk = self._pool(h, batch)
        else:
            # Fixed-order MLP: one row of flattened coordinates per graph.
            bsz = time.shape[0]
            coords = x.coords.reshape(bsz, self.num_nodes * 3)
            trunk = self.encoder(coords)

        t = time.to(dtype=trunk.dtype, device=trunk.device)
        two_pi_t = 2.0 * torch.pi * t
        time_feats = torch.cat([t, torch.sin(two_pi_t), torch.cos(two_pi_t)], dim=-1)
        pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
        trans = self.trans_head(pooled)
        rot_omega = self.rot_head(pooled)
        if self.torsion_head_mode == "bond_pair":
            assert self.model_type == "gcn"
            # h holds the final trunk node embeddings from the GCN loop above.
            torsion = self._torsion_from_subgraph_specs(x, pooled, h)
        else:
            assert self.torsion_head is not None
            torsion = self.torsion_head(pooled)
        return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Data loading
|
||
# ---------------------------------------------------------------------------
|
||
def load_sample_from_pickle(path: str, pos: int) -> Data:
    """Unpickle the record at byte offset ``pos`` of ``path`` and wrap it as a ``Data`` graph.

    Only the first two rows of the record's ``edge_index`` are kept. The
    ``coords``, ``num_nodes`` and ``rdmol`` fields are always copied; the
    ground-truth and torsion fields are copied only when present.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only load shards from trusted sources.
    """
    optional_fields = ("gt_coords", "gt_trans", "gt_quaternion", "gt_torsion", "bond_endpoints", "movable_mask")
    with open(path, "rb") as fh:
        fh.seek(pos)
        record = pickle.load(fh)
    record.edge_index = record.edge_index[:2, :]
    out = Data(coords=record.coords, edge_index=record.edge_index)
    out.num_nodes = int(record.coords.shape[0])
    out.rdmol = record.rdmol
    for field in optional_fields:
        if hasattr(record, field):
            setattr(out, field, getattr(record, field))
    return out
|
||
|
||
|
||
def load_sample_from_sdf(sdf_path: str) -> Data:
    """Parse ``sdf_path`` with ``Chem.MolFromMolFile`` (hydrogens kept) into a ``Data`` graph.

    Sets:
      * ``coords``: float32 ``[N, 3]`` positions from the molecule's conformer,
      * ``edge_index``: ``[2, 2 * num_bonds]``, each bond stored in both directions,
      * ``num_nodes`` and ``rdmol``.

    A molecule without bonds now yields an empty ``[2, 0]`` edge index.
    Previously this path produced a 1-D ``[0]`` tensor, which broke later
    ``edge_index[0]`` / ``edge_index[1]`` indexing.

    Raises:
        RuntimeError: if RDKit cannot parse the file.
    """
    mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
    if mol is None:
        raise RuntimeError(f"RDKit could not read {sdf_path}")
    conf = mol.GetConformer()
    positions = []
    for i in range(mol.GetNumAtoms()):
        p = conf.GetAtomPosition(i)
        positions.append([p.x, p.y, p.z])
    coords = torch.tensor(positions, dtype=torch.float32).reshape(-1, 3)

    edges: list[tuple[int, int]] = []
    for bond in mol.GetBonds():
        a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append((a, b))
        edges.append((b, a))
    # reshape(-1, 2) keeps the bond-free case 2-D ([0, 2] -> [2, 0] after .T).
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).T.contiguous()

    data = Data()
    data.coords = coords
    data.edge_index = edge_index
    data.num_nodes = int(coords.shape[0])
    data.rdmol = Chem.Mol(mol)
    return data
|
||
|
||
|
||
def prepare_sample(sample: Data) -> Data:
    """Annotate ``sample`` in place with ground-truth pose fields and torsion masks, then return it.

    Sets:
      * ``num_nodes``,
      * ``gt_coords``: a copy of the current ``coords``,
      * ``gt_trans``: the float32 centroid,
      * ``gt_quaternion``: the identity rotation in scalar-first form ``(1, 0, 0, 0)``,
      * ``bond_endpoints`` / ``movable_mask``: from ``bfs_torsion_masks``,
      * ``gt_torsion``: zeros, one per rotatable bond.

    ``gt_quaternion`` used to be ``(0, 0, 0, 0)``, which is not a unit
    quaternion. ``generate_rotation_matrix`` happens to map both values to the
    identity matrix, but normalizing the zero quaternion yields NaN.
    """
    sample.num_nodes = int(sample.coords.shape[0])
    sample.gt_coords = sample.coords.clone()
    sample.gt_trans = torch.mean(sample.coords, dim=0, dtype=torch.float32)
    sample.gt_quaternion = torch.tensor((1, 0, 0, 0), dtype=torch.float32)
    sample.bond_endpoints, sample.movable_mask = bfs_torsion_masks(sample)
    sample.gt_torsion = torch.zeros(sample.bond_endpoints.shape[0], dtype=torch.float32)
    return sample
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CPU batch building (DataLoader workers or main process)
|
||
# ---------------------------------------------------------------------------
|
||
def _gt_state_cpu(gs: RigidTorsionState) -> RigidTorsionState:
    """Return a new state whose tensors are ``gs``'s tensors moved to the CPU.

    ``.cpu()`` returns CPU tensors unchanged.
    """
    center, rot, torsion = (tensor.cpu() for tensor in (gs.center, gs.rot, gs.torsion))
    return RigidTorsionState(center=center, rot=rot, torsion=torsion)
|
||
|
||
|
||
def make_rfm_example(
    sample: Data,
    gt_state_cpu: RigidTorsionState,
    cpu: torch.device,
    time_power: float = 1.0,
) -> dict[str, Any]:
    """Draw one random flow-matching training example using CPU tensors only.

    Steps:
      1. Sample a start state: Gaussian translation (std 3), uniform random
         rotation, and uniform torsions.
      2. Compute the constant velocity that reaches ``gt_state_cpu`` by time 1.
      3. Draw ``t = U(0, 1) ** time_power`` and move the start state to time ``t``.
      4. Pose a deep copy of ``sample`` at that state (starting from
         ``sample.coords``).

    Intended for use inside DataLoader workers.

    Returns:
        dict with ``data`` (posed graph), ``v_center``, ``v_omega``,
        ``v_torsion`` (velocity targets) and ``time`` (0-d tensor).
    """
    zero_mean = torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu)
    trans_std = torch.tensor([3, 3, 3], dtype=torch.float32, device=cpu)
    start = RigidTorsionState(
        center=generate_random_trans(zero_mean, trans_std),
        rot=generate_rotation_matrix(generate_random_quaternion()),
        torsion=generate_random_torsion_angles(sample.bond_endpoints.shape[0], device=cpu),
    )
    velocity = rigid_torsion_velocity(start, gt_state_cpu, torch.tensor(0.0, device=cpu))
    t_scalar = torch.rand((), device=cpu) ** time_power
    pose = get_state(start, velocity, float(t_scalar.item()))

    g = deepcopy(sample)
    g.coords = g.coords.to(device=cpu, dtype=torch.float32)
    g.edge_index = g.edge_index.to(device=cpu)
    for attr in ("movable_mask", "bond_endpoints"):
        if hasattr(g, attr):
            setattr(g, attr, getattr(g, attr).to(device=cpu))

    apply_trans(g, pose.center)
    apply_rot_matrix(g, pose.rot)
    apply_torsion(g, pose.torsion, g.movable_mask, g.bond_endpoints)

    return {
        "data": g,
        "v_center": velocity.center.detach().cpu(),
        "v_omega": velocity.rot_omega.detach().cpu(),
        "v_torsion": velocity.torsion.detach().cpu(),
        "time": t_scalar.detach().cpu(),
    }
|
||
|
||
|
||
def rfm_collate(batch: list[dict[str, Any]]) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate ``make_rfm_example`` dicts into a training tuple.

    Returns ``(graphs, time, v_center, v_omega, v_torsion)``, where ``time``
    is ``[B, 1]`` and each target is stacked along a new leading batch dim.
    """

    def _stack(key: str) -> torch.Tensor:
        return torch.stack([row[key] for row in batch], dim=0)

    graphs = Batch.from_data_list([row["data"] for row in batch])
    time_b = _stack("time").unsqueeze(-1)
    return graphs, time_b, _stack("v_center"), _stack("v_omega"), _stack("v_torsion")
|
||
|
||
|
||
class SameGraphRandomRFMDataset(Dataset):
    """Map-style dataset of fresh random RFM examples over one fixed molecule.

    ``__getitem__`` ignores its index and calls ``make_rfm_example`` on the
    CPU, so every access is a new random draw. ``length`` only sets
    ``__len__``, i.e. the epoch size a ``DataLoader`` sees.
    """

    def __init__(
        self,
        sample: Data,
        gt_state_cpu: RigidTorsionState,
        length: int,
        time_power: float = 1.0,
    ):
        self.sample = sample
        self.gt_state = _gt_state_cpu(gt_state_cpu)
        self.length = length
        self.time_power = time_power

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        del idx  # every item is an independent random draw
        cpu = torch.device("cpu")
        return make_rfm_example(self.sample, self.gt_state, cpu, self.time_power)
|
||
|
||
|
||
def build_one_batch(
    sample: Data,
    gt_state: RigidTorsionState,
    batch_size: int,
    cpu: torch.device,
    time_power: float = 1.0,
) -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build one collated batch of ``batch_size`` random examples in the calling process.

    The output has the same structure as a DataLoader over
    ``SameGraphRandomRFMDataset`` with ``collate_fn=rfm_collate``, but no
    worker processes are used.
    """
    examples = []
    for _ in range(batch_size):
        examples.append(make_rfm_example(sample, gt_state, cpu, time_power=time_power))
    return rfm_collate(examples)
|
||
|
||
|
||
def infinite_batches(loader: DataLoader):
    """Yield batches from ``loader`` forever, re-iterating it each time a pass ends.

    Args:
        loader: Any re-iterable source of batches (typically a ``DataLoader``).

    Yields:
        Successive items from repeated passes over ``loader``.

    Raises:
        RuntimeError: If a complete pass over ``loader`` produces no items. Without
            this check an empty loader (e.g. a zero-length dataset) would make the
            generator spin forever without ever yielding.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise RuntimeError("infinite_batches: loader produced no batches (empty dataset?)")
|
||
|
||
|
||
def _build_train_arg_parser(here: str) -> argparse.ArgumentParser:
    """Build the CLI parser for ``main``; ``here`` anchors the default SDF / output paths."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pickle-path", default="", help="Optional pickle shard path")
    parser.add_argument("--pickle-pos", type=int, default=61908729)
    parser.add_argument("--sdf", default=os.path.join(here, "sample.sdf"), help="Fallback ligand SDF")
    parser.add_argument("--out-dir", default=os.path.join(here, "rfm_out"))
    parser.add_argument("--epochs", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
    parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
    parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
    parser.add_argument(
        "--tail-risk-weight",
        type=float,
        default=0.0,
        help="Extra weight on upper-quantile per-sample loss (0 disables).",
    )
    parser.add_argument(
        "--tail-risk-quantile",
        type=float,
        default=0.9,
        help="Upper quantile threshold for tail-risk penalty.",
    )
    parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
    parser.add_argument(
        "--model-type",
        type=str,
        default="mlp",
        choices=["mlp", "gcn"],
        help="Backbone encoder for velocity prediction.",
    )
    parser.add_argument(
        "--gcn-layers",
        type=int,
        default=4,
        help="Number of GCNConv layers (more => longer GPU kernels)",
    )
    parser.add_argument(
        "--accum",
        type=int,
        default=1,
        help="Gradient accumulation steps per optimizer step (keeps GPU busier vs CPU batch build)",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="torch.compile(model) when supported (PyTorch 2+)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=-1,
        help="DataLoader workers; -1 = min(8, CPU-1), 0 = main process only",
    )
    parser.add_argument(
        "--prefetch-factor",
        type=int,
        default=4,
        help="Per-worker prefetch queue depth (ignored if num-workers=0)",
    )
    parser.add_argument(
        "--pin-memory",
        action="store_true",
        help="DataLoader pin_memory (faster H2D; can hit shm/pin issues in some setups)",
    )
    parser.add_argument(
        "--sync-batch",
        action="store_true",
        help="Build batches in main thread (no DataLoader workers)",
    )
    parser.add_argument(
        "--dataset-len",
        type=int,
        default=1_000_000,
        help="Synthetic Dataset length (only affects shuffle / epoch size)",
    )
    parser.add_argument(
        "--eval-runs",
        type=int,
        default=100,
        help="Number of rollout tests for final RMSD report",
    )
    parser.add_argument(
        "--rollout-steps",
        type=int,
        default=100,
        help="Integration steps per rollout during final test",
    )
    parser.add_argument(
        "--time-power",
        type=float,
        default=1.0,
        help="Sample t as U(0,1)^time_power; >1 biases training toward t~0 states.",
    )
    parser.add_argument(
        "--loss-domain",
        type=str,
        default="displacement",
        choices=["velocity", "displacement"],
        help="Train on raw velocity or displacement-to-target (v * (1 - t)).",
    )
    parser.add_argument(
        "--rotation-loss",
        type=str,
        default="mse",
        choices=["mse", "geodesic", "hybrid"],
        help="Rotation-channel loss type.",
    )
    parser.add_argument(
        "--rotation-hybrid-alpha",
        type=float,
        default=0.5,
        help="For rotation-loss=hybrid: weight on geodesic term (0..1).",
    )
    parser.add_argument(
        "--rotation-weight-start",
        type=float,
        default=-1.0,
        help="If >=0, linearly anneal effective rotation weight from start to --weight-omega.",
    )
    parser.add_argument(
        "--rotation-weight-warmup-epochs",
        type=int,
        default=0,
        help="Warmup epochs for rotation weight schedule.",
    )
    parser.add_argument(
        "--gcn-residual",
        action="store_true",
        help="Enable residual skip in GCN trunk.",
    )
    parser.add_argument(
        "--channel-layernorm",
        action="store_true",
        help="Apply layernorm to pooled trunk and head hidden layers.",
    )
    parser.add_argument(
        "--head-mlp-layers",
        type=int,
        default=1,
        help="Per-channel head depth (>=1).",
    )
    parser.add_argument(
        "--graph-readout",
        type=str,
        default="mean",
        choices=["mean", "attention"],
        help="GCN graph readout: mean pool vs learned attention pool (ignored for model-type=mlp).",
    )
    parser.add_argument(
        "--torsion-head",
        type=str,
        default="global",
        choices=["global", "bond_pair"],
        help=(
            "global: torsion vector from pooled trunk+time. "
            "bond_pair (GCN only): induced subgraph = movable ∪ bond (i,j); subgraph node input = concat(coords, "
            "full-graph trunk node h); dedicated sub_convs; batched B×T forward; MLP on [global pooled || subgraph pool]."
        ),
    )
    parser.add_argument(
        "--subgraph-lr-scale",
        type=float,
        default=0.3,
        help="bond_pair: LR multiplier for subgraph encoder + torsion head (main trunk uses --lr).",
    )
    parser.add_argument(
        "--subgraph-torsion-clip",
        type=float,
        default=6.0,
        help="bond_pair: clamp predicted torsion channel to [-c, c]; use 0 to disable.",
    )
    parser.add_argument(
        "--torsion-sub-gcn-layers",
        type=int,
        default=-1,
        help="bond_pair: subgraph GCN depth (-1 = same as --gcn-layers).",
    )
    parser.add_argument(
        "--early-stop-patience",
        type=int,
        default=0,
        help="Train-loss early stop patience in checks (0 disables).",
    )
    parser.add_argument(
        "--early-stop-min-delta",
        type=float,
        default=0.0,
        help="Minimum train-loss improvement required to reset patience.",
    )
    parser.add_argument(
        "--early-stop-check-every",
        type=int,
        default=1,
        help="Evaluate early-stop condition every N epochs.",
    )
    parser.add_argument(
        "--early-stop-warmup",
        type=int,
        default=0,
        help="Skip early-stop checks for first N epochs.",
    )
    parser.add_argument(
        "--omega-max-norm",
        type=float,
        default=0.0,
        help="If >0, clip predicted rotation omega vector norm for stability.",
    )
    parser.add_argument(
        "--ema-decay",
        type=float,
        default=0.0,
        help=(
            "If in (0,1), maintain an EMA shadow of model weights after each optimizer step; "
            "best_train checkpoint and final eval load the EMA weights (training uses online weights)."
        ),
    )
    return parser


def _rotation_loss_weight(args: argparse.Namespace, epoch: int) -> float:
    """Return the rotation-loss weight for ``epoch``.

    With ``--rotation-weight-start >= 0`` the weight is linearly ramped from that start
    value to ``--weight-omega`` over ``--rotation-weight-warmup-epochs``; otherwise it
    is the constant ``--weight-omega``.
    """
    if args.rotation_weight_start < 0:
        return args.weight_omega
    frac = max(0.0, min(1.0, epoch / max(1, args.rotation_weight_warmup_epochs)))
    return args.rotation_weight_start + frac * (args.weight_omega - args.rotation_weight_start)


def _rfm_training_loss(
    pred: Any,
    time_b: torch.Tensor,
    tc: torch.Tensor,
    tw: torch.Tensor,
    tt: torch.Tensor,
    args: argparse.Namespace,
    weight_omega: float,
) -> torch.Tensor:
    """Weighted, scale-normalized RFM loss for one micro-batch.

    Args:
        pred: Model output exposing ``.center``, ``.rot_omega`` and ``.torsion``.
        time_b: Per-example times, broadcastable against the targets.
        tc, tw, tt: Target center / rotation / torsion velocities.
        args: Parsed CLI args (loss domain, rotation-loss mode, weights, tail risk).
        weight_omega: Effective rotation weight for the current epoch.

    Returns:
        Scalar loss tensor: center + rotation + wrapped-torsion terms, plus an optional
        upper-quantile tail-risk penalty.
    """
    time_remain = (1.0 - time_b).clamp(min=1e-6)
    pred_omega_raw = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm)
    if args.loss_domain == "displacement":
        # Compare displacement-to-target v * (1 - t) rather than the raw velocity.
        pred_center = pred.center * time_remain
        pred_omega = pred_omega_raw * time_remain
        pred_torsion = pred.torsion * time_remain
        tgt_center = tc * time_remain
        tgt_omega = tw * time_remain
        tgt_torsion = tt * time_remain
    else:
        pred_center, pred_omega, pred_torsion = pred.center, pred_omega_raw, pred.torsion
        tgt_center, tgt_omega, tgt_torsion = tc, tw, tt

    eps = 1e-8
    # Normalize each channel by its target's mean square so channel magnitudes are comparable.
    scale_c = (tgt_center * tgt_center).mean().clamp(min=eps)
    scale_w = (tgt_omega * tgt_omega).mean().clamp(min=eps)
    per_center = ((pred_center - tgt_center) ** 2).mean(dim=-1) / scale_c
    # Per-sample rotation MSE is used only by the tail-risk term, for every rotation-loss mode.
    per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
    if args.rotation_loss == "geodesic":
        rot_loss = so3_geodesic_mse_loss(pred_omega, tgt_omega)
    elif args.rotation_loss == "hybrid":
        alpha = float(max(0.0, min(1.0, args.rotation_hybrid_alpha)))
        rot_loss = (
            alpha * so3_geodesic_mse_loss(pred_omega, tgt_omega)
            + (1.0 - alpha) * F.mse_loss(pred_omega, tgt_omega)
        )
    else:
        rot_loss = F.mse_loss(pred_omega, tgt_omega)

    if tgt_torsion.shape[-1] == 0:
        # Rigid ligand (no rotatable bonds): a mean over an empty axis is NaN, which would
        # make every loss non-finite; the torsion channel contributes nothing instead.
        per_torsion = per_center.new_zeros(per_center.shape)
    else:
        scale_t = (tgt_torsion * tgt_torsion).mean().clamp(min=eps)
        per_torsion = (wrap_angle(pred_torsion - tgt_torsion) ** 2).mean(dim=-1) / scale_t

    loss = (
        args.weight_center * per_center.mean()
        + weight_omega * (rot_loss / scale_w)
        + args.weight_torsion * per_torsion.mean()
    )
    if args.tail_risk_weight > 0:
        per_sample_total = (
            args.weight_center * per_center
            + weight_omega * per_rot
            + args.weight_torsion * per_torsion
        )
        loss = loss + args.tail_risk_weight * upper_quantile_mean(per_sample_total, args.tail_risk_quantile)
    return loss


def main() -> int:
    """Train the RFM toy model on one ligand, evaluate rollouts, and write artifacts.

    Writes SDF visualizations under ``--out-dir``, the best-train checkpoint to
    ``artifacts/latest_eval_best_model.pt`` and a JSON report to
    ``reports/latest_eval.json`` (both relative to this file's directory).

    Returns:
        Process exit code: 0 on success, 1 when neither a pickle nor an SDF sample exists.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    args = _build_train_arg_parser(here).parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")

    if args.pickle_path and os.path.isfile(args.pickle_path):
        print("Loading sample from pickle:", args.pickle_path, "pos", args.pickle_pos)
        sample = load_sample_from_pickle(args.pickle_path, args.pickle_pos)
    elif os.path.isfile(args.sdf):
        print("Loading sample from SDF:", args.sdf)
        sample = load_sample_from_sdf(args.sdf)
    else:
        print("No pickle or SDF found. Provide --pickle-path or --sdf", file=sys.stderr)
        return 1

    sample = prepare_sample(sample)
    torsion_head_mode = args.torsion_head
    if torsion_head_mode == "bond_pair" and args.model_type != "gcn":
        print(
            "Warning: --torsion-head bond_pair requires --model-type gcn; using global torsion head.",
            file=sys.stderr,
        )
        torsion_head_mode = "global"
    os.makedirs(args.out_dir, exist_ok=True)
    save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))

    # NOTE(review): the all-zero quaternion is presumably mapped to the identity rotation
    # by generate_rotation_matrix — confirm against its definition.
    gt_state = RigidTorsionState(
        center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
        rot=generate_rotation_matrix(torch.tensor([0, 0, 0, 0], dtype=torch.float32)),
        torsion=torch.zeros(sample.bond_endpoints.shape[0], dtype=torch.float32),
    )

    target_state = RigidTorsionState(
        center=torch.tensor([0, 0, 0], dtype=torch.float32, device=cpu),
        rot=generate_rotation_matrix(sample.gt_quaternion),
        torsion=sample.gt_torsion,
    )

    torsion_subgraphs: nn.ModuleList | None = None
    if torsion_head_mode == "bond_pair":
        torsion_subgraphs = build_torsion_subgraph_specs(sample)

    torsion_sub_gcn_layers = None if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)

    model = RFMModel(
        model_type=args.model_type,
        num_nodes=int(sample.num_nodes),
        num_torsions=int(sample.bond_endpoints.shape[0]),
        hidden=args.hidden,
        num_layers=args.gcn_layers,
        gcn_residual=args.gcn_residual,
        channel_layernorm=args.channel_layernorm,
        head_mlp_layers=args.head_mlp_layers,
        graph_readout=args.graph_readout,
        torsion_head_mode=torsion_head_mode,
        torsion_subgraphs=torsion_subgraphs,
        torsion_subgraph_output_clip=float(args.subgraph_torsion_clip),
        torsion_sub_gcn_layers=torsion_sub_gcn_layers,
    ).to(device)

    # Identify subgraph-branch parameters by object id *before* torch.compile wraps the
    # module (the wrapper exposes the original as `_orig_mod`, which changes name prefixes).
    subgraph_param_ids: set[int] = set()
    if torsion_head_mode == "bond_pair":
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if name.startswith(("sub_convs", "sub_gcn_skip", "torsion_pair", "sub_pool_norm")):
                subgraph_param_ids.add(id(p))
    if args.compile and hasattr(torch, "compile"):
        model = torch.compile(model)  # type: ignore[assignment]
    if torsion_head_mode == "bond_pair":
        base_params: list[nn.Parameter] = []
        sub_params: list[nn.Parameter] = []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            if id(p) in subgraph_param_ids:
                sub_params.append(p)
            else:
                base_params.append(p)
        optimizer = torch.optim.Adam(
            [
                {"params": base_params, "lr": args.lr},
                {"params": sub_params, "lr": args.lr * float(args.subgraph_lr_scale)},
            ],
        )
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def model_core(m: nn.Module) -> nn.Module:
        # Unwrap a torch.compile wrapper (if any) to reach the module holding the state_dict.
        return m._orig_mod if hasattr(m, "_orig_mod") else m

    ema_decay = float(args.ema_decay)
    ema_state: dict[str, torch.Tensor] | None = None
    if 0.0 < ema_decay < 1.0:
        core0 = model_core(model)
        ema_state = {k: v.detach().clone() for k, v in core0.state_dict().items()}

    @torch.no_grad()
    def ema_update_(shadow: dict[str, torch.Tensor], m: nn.Module, decay: float) -> None:
        # Floating tensors are exponentially averaged; non-floating entries are copied as-is.
        src = model_core(m).state_dict()
        for k, dst in shadow.items():
            t = src[k]
            if torch.is_floating_point(t):
                dst.mul_(decay).add_(t, alpha=1.0 - decay)
            else:
                dst.copy_(t)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(
        f"device={device} model={args.model_type} hidden={args.hidden} "
        f"layers={args.gcn_layers} params={n_params:,}"
    )
    print(f"time sampling power={args.time_power}")
    print(
        "loss weights: "
        f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
        f"grad_clip={args.grad_clip}"
    )
    print(f"tail_risk: weight={args.tail_risk_weight} quantile={args.tail_risk_quantile}")
    print("loss domain=%s torsion_wrapped_loss=always_on" % args.loss_domain)
    print(
        f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
        f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers} "
        f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
    )
    if torsion_head_mode == "bond_pair":
        sub_l = args.gcn_layers if int(args.torsion_sub_gcn_layers) < 0 else int(args.torsion_sub_gcn_layers)
        print(
            f"bond_pair_stab: subgraph_lr_scale={args.subgraph_lr_scale} "
            f"subgraph_torsion_clip={args.subgraph_torsion_clip} "
            f"torsion_sub_gcn_layers={sub_l}"
        )
    if args.rotation_loss == "hybrid":
        print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
    if args.rotation_weight_start >= 0:
        print(
            "rotation_weight_schedule: "
            f"start={args.rotation_weight_start} end={args.weight_omega} "
            f"warmup_epochs={args.rotation_weight_warmup_epochs}"
        )
    print(
        "early_stop: "
        f"patience={args.early_stop_patience} min_delta={args.early_stop_min_delta} "
        f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
    )
    print(f"omega_max_norm={args.omega_max_norm}")
    if ema_state is not None:
        print(f"ema_decay={ema_decay} (best checkpoint + final eval use EMA weights)")

    nw = args.num_workers
    if nw < 0:
        nw = min(8, max(0, (os.cpu_count() or 4) - 1))
    pf = args.prefetch_factor if nw > 0 else None
    persistent = nw > 0

    if args.sync_batch:

        def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return build_one_batch(
                sample,
                gt_state,
                args.batch_size,
                cpu,
                args.time_power,
            )

        print("batching: sync (main thread) workers=0")
    else:
        dataset = SameGraphRandomRFMDataset(
            sample,
            gt_state_cpu=gt_state,
            length=args.dataset_len,
            time_power=args.time_power,
        )
        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=nw,
            prefetch_factor=pf,
            persistent_workers=persistent,
            pin_memory=args.pin_memory and torch.cuda.is_available(),
            collate_fn=rfm_collate,
        )
        _batch_iter = infinite_batches(loader)

        def get_train_batch() -> tuple[Batch, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return next(_batch_iter)

        print(
            f"batching: DataLoader num_workers={nw} prefetch_factor={pf!s} "
            f"pin_memory={args.pin_memory and torch.cuda.is_available()} "
            f"persistent_workers={persistent}"
        )

    t0 = time.perf_counter()
    log_every = max(1, args.epochs // 10)
    accum = max(1, args.accum)
    best_train_mse = float("inf")
    best_state: dict[str, torch.Tensor] | None = None
    early_stop_best = float("inf")
    early_stop_bad_checks = 0
    stop_reason = "max_epochs"
    # Stays -1 when --epochs <= 0 (no training); otherwise tracks the last epoch entered.
    stop_epoch = args.epochs - 1
    # Loss value recorded for a micro-batch whose loss is NaN/Inf (keeps such epochs from
    # ever looking like the best one).
    nonfinite_loss_value = 1e3
    nonfinite_microbatches = 0

    for epoch in range(args.epochs):
        stop_epoch = epoch
        current_weight_omega = _rotation_loss_weight(args, epoch)
        optimizer.zero_grad(set_to_none=True)
        sum_mse = torch.zeros((), device=device)
        for _ in range(accum):
            batched, time_b, tc, tw, tt = get_train_batch()
            batched = batched.to(device, non_blocking=True)
            time_b = time_b.to(device=device, dtype=torch.float32, non_blocking=True)
            tc = tc.to(device=device, dtype=torch.float32, non_blocking=True)
            tw = tw.to(device=device, dtype=torch.float32, non_blocking=True)
            tt = tt.to(device=device, dtype=torch.float32, non_blocking=True)

            pred = model(batched, time_b)
            mse = _rfm_training_loss(pred, time_b, tc, tw, tt, args, current_weight_omega)
            if torch.isfinite(mse):
                (mse / accum).backward()
                sum_mse = sum_mse + mse.detach()
            else:
                # Skip backward entirely: nan_to_num only zeroes the top-level gradient, and
                # upstream products like 0 * nan would still write NaN into parameter grads.
                nonfinite_microbatches += 1
                sum_mse = sum_mse + nonfinite_loss_value

        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        if ema_state is not None:
            ema_update_(ema_state, model, ema_decay)

        avg_mse = sum_mse / accum
        avg_mse_val = float(avg_mse.detach().cpu())
        if avg_mse_val < best_train_mse:
            best_train_mse = avg_mse_val
            core = model_core(model)
            if ema_state is not None:
                best_state = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
            else:
                best_state = {k: v.detach().cpu().clone() for k, v in core.state_dict().items()}

        should_check_early_stop = (
            args.early_stop_patience > 0
            and epoch >= args.early_stop_warmup
            and ((epoch - args.early_stop_warmup) % max(1, args.early_stop_check_every) == 0)
        )
        if should_check_early_stop:
            improved = avg_mse_val < (early_stop_best - args.early_stop_min_delta)
            if improved:
                early_stop_best = avg_mse_val
                early_stop_bad_checks = 0
            else:
                early_stop_bad_checks += 1
                if early_stop_bad_checks >= args.early_stop_patience:
                    stop_reason = (
                        "early_stop_train_loss_plateau"
                        f"(patience={args.early_stop_patience},"
                        f"min_delta={args.early_stop_min_delta},"
                        f"check_every={max(1, args.early_stop_check_every)})"
                    )
                    print(
                        f"early stopping at epoch {epoch}: "
                        f"bad_checks={early_stop_bad_checks}"
                    )
                    break

        if epoch % log_every == 0 or epoch == args.epochs - 1:
            print(f"epoch {epoch:6d} train_mse {avg_mse_val:.6f} best_train_mse {best_train_mse:.6f}")

    if device.type == "cuda":
        torch.cuda.synchronize()
    print(
        f"training done in {time.perf_counter() - t0:.1f}s "
        f"(stop_reason={stop_reason}, stop_epoch={stop_epoch})"
    )
    if nonfinite_microbatches:
        print(
            f"warning: skipped backward for {nonfinite_microbatches} micro-batch(es) with non-finite loss",
            file=sys.stderr,
        )

    ckpt_path = ""
    if best_state is not None:
        core = model_core(model)
        core.load_state_dict(best_state)
        print(f"loaded best model from training (best_train_mse={best_train_mse:.6f})")
        artifacts_dir = os.path.join(here, "artifacts")
        os.makedirs(artifacts_dir, exist_ok=True)
        ckpt_path = os.path.join(artifacts_dir, "latest_eval_best_model.pt")
        ckpt = {
            "state_dict": best_state,
            "model_type": args.model_type,
            "hidden": int(args.hidden),
            "num_layers": int(args.gcn_layers),
            "num_nodes": int(sample.num_nodes),
            "num_torsions": int(sample.bond_endpoints.shape[0]),
            "sample_sdf": args.sdf,
            "sample_pickle_path": args.pickle_path,
            "sample_pickle_pos": int(args.pickle_pos),
            "eval_runs": int(args.eval_runs),
            "rollout_steps": int(args.rollout_steps),
            "ema_decay": float(ema_decay) if ema_state is not None else 0.0,
            # RFMModel forward parity (required for trajectory / tooling reloads)
            "gcn_residual": bool(args.gcn_residual),
            "channel_layernorm": bool(args.channel_layernorm),
            "head_mlp_layers": int(args.head_mlp_layers),
            "graph_readout": str(args.graph_readout),
            "torsion_head": str(torsion_head_mode),
            "subgraph_torsion_clip": float(args.subgraph_torsion_clip),
            "torsion_sub_gcn_layers": int(args.torsion_sub_gcn_layers),
            "omega_max_norm": float(args.omega_max_norm),
        }
        torch.save(ckpt, ckpt_path)
        print(f"saved best checkpoint: {ckpt_path}")

    # Final test metric: mean final-structure RMSD over --eval-runs rollouts (default 100).
    final_mean_rmsd_100 = evaluate_mean_rmsd_100(
        model,
        sample,
        sample.gt_coords,
        device=device,
        num_runs=args.eval_runs,
        rollout_steps=args.rollout_steps,
        omega_max_norm=args.omega_max_norm,
    )
    print(f"final_test_mean_rmsd_{args.eval_runs}: {final_mean_rmsd_100:.6f}")

    reports_dir = os.path.join(here, "reports")
    os.makedirs(reports_dir, exist_ok=True)
    latest_path = os.path.join(reports_dir, "latest_eval.json")
    timestamp = datetime.now(timezone.utc).isoformat()
    cmd = " ".join(sys.argv)

    latest_report = {
        # Key name is fixed regardless of --eval-runs; "num_runs" carries the actual count.
        "mean_rmsd_100": float(final_mean_rmsd_100),
        "num_runs": int(args.eval_runs),
        "timestamp_utc": timestamp,
        "command": cmd,
        "notes": f"Final test metric from {args.eval_runs} random-initialized rollouts to time=1.",
        "best_train_mse": float(best_train_mse),
        "model_source": "best_train_checkpoint",
        "checkpoint_path": ckpt_path,
        "stop_reason": stop_reason,
        "stop_epoch": int(stop_epoch),
        "ema_decay": float(ema_decay) if ema_state is not None else 0.0,
    }
    with open(latest_path, "w", encoding="utf-8") as f:
        json.dump(latest_report, f, indent=2)

    # --- Visualizations ---
    sample.coords = sample.gt_coords.clone()
    save_visualization(sample, os.path.join(args.out_dir, "01_gt_reset.sdf"))

    random_trans = generate_random_trans(
        torch.tensor([0, 0, 0], dtype=torch.float32), torch.tensor([3, 3, 3], dtype=torch.float32)
    )
    random_rot = generate_rotation_matrix(generate_random_quaternion())
    random_torsion = generate_random_torsion_angles(sample.bond_endpoints.shape[0])
    random_state = RigidTorsionState(center=random_trans, rot=random_rot, torsion=random_torsion)

    vis = deepcopy(sample)
    vis.coords = vis.gt_coords.clone()
    apply_trans(vis, random_trans)
    apply_rot_matrix(vis, random_rot)
    apply_torsion(vis, random_torsion, vis.movable_mask, vis.bond_endpoints)
    save_visualization(vis, os.path.join(args.out_dir, "02_random_init.sdf"))

    velocity_gt = rigid_torsion_velocity(random_state, target_state, torch.tensor(0.0))
    state_list_gt = [get_state(random_state, velocity_gt, float(t)) for t in torch.linspace(0, 1, 50)]
    save_video_visualization(sample, state_list_gt, os.path.join(args.out_dir, "03_trajectory_gt_velocity.sdf"))

    # Learned trajectory: Euler steps using model velocity (batch size 1)
    model.eval()
    state = RigidTorsionState(
        center=random_state.center.to(device),
        rot=random_state.rot.to(device),
        torsion=random_state.torsion.to(device),
    )
    n_frames = 100
    learned_states: list[RigidTorsionState] = []
    for k in range(n_frames):
        # Record the current state (k Euler steps taken) before advancing it.
        learned_states.append(
            RigidTorsionState(
                center=state.center.detach().cpu(),
                rot=state.rot.detach().cpu(),
                torsion=state.torsion.detach().cpu(),
            )
        )
        t_val = k / max(1, n_frames - 1)
        g = deepcopy(sample)
        g.coords = sample.gt_coords.clone()
        apply_trans(g, state.center.cpu())
        apply_rot_matrix(g, state.rot.cpu())
        apply_torsion(g, state.torsion.cpu(), g.movable_mask, g.bond_endpoints)
        g = g.to(device)
        batch = Batch.from_data_list([g]).to(device)
        tb = torch.tensor([[t_val]], device=device, dtype=torch.float32)
        with torch.no_grad():
            v = model(batch, tb)
        vc = v.center.squeeze(0)
        vw = clip_vector_norm_batch(v.rot_omega, args.omega_max_norm).squeeze(0)
        vt = v.torsion.squeeze(0)
        dt = 1.0 / max(1, n_frames - 1)
        state = RigidTorsionState(
            center=state.center + vc * dt,
            rot=state.rot @ so3_exp(vw * dt),
            torsion=state.torsion + vt * dt,
        )

    save_video_visualization(sample, learned_states, os.path.join(args.out_dir, "04_trajectory_learned_model.sdf"))

    print("Wrote SDFs under:", args.out_dir)
    for name in sorted(os.listdir(args.out_dir)):
        if name.endswith(".sdf"):
            print(" ", name)
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
|