Merge branch 'attempt/s3-restart-after-doc-sync' into post-main-doc-sync
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
{
|
||||
"mean_rmsd_100": 2.464729991555214,
|
||||
"mean_rmsd_100": 2.598530296087265,
|
||||
"num_runs": 100,
|
||||
"timestamp_utc": "2026-04-16T14:27:02.733292+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --seed 1",
|
||||
"timestamp_utc": "2026-04-16T14:56:28.000610+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --seed 1",
|
||||
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
||||
"best_train_mse": 3.5417799949645996,
|
||||
"best_train_mse": 6.003782272338867,
|
||||
"model_source": "best_train_checkpoint",
|
||||
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
|
||||
"stop_reason": "max_epochs",
|
||||
|
||||
312
train.py
312
train.py
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
@@ -26,7 +27,8 @@ from rdkit.Chem import SDWriter
|
||||
from rdkit.Geometry.rdGeometry import Point3D
|
||||
from torch import nn
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool
|
||||
from torch_geometric.utils import softmax
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -207,6 +209,19 @@ def torsion_wrapped_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.
|
||||
return torch.mean(wrap_angle(pred - target) ** 2)
|
||||
|
||||
|
||||
def upper_quantile_mean(values: torch.Tensor, quantile: float) -> torch.Tensor:
|
||||
"""Mean of upper-quantile tail values."""
|
||||
q = float(max(0.0, min(1.0, quantile)))
|
||||
if q <= 0.0:
|
||||
return values.mean()
|
||||
if q >= 1.0:
|
||||
return values.max()
|
||||
n = int(values.numel())
|
||||
k = max(1, int(math.ceil(n * (1.0 - q))))
|
||||
top_vals, _ = torch.topk(values.reshape(-1), k=k, largest=True)
|
||||
return top_vals.mean()
|
||||
|
||||
|
||||
def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor) -> torch.Tensor:
|
||||
"""Geodesic MSE between batched axis-angle increments."""
|
||||
if pred_omega.ndim != 2 or pred_omega.shape[-1] != 3:
|
||||
@@ -223,6 +238,14 @@ def so3_geodesic_mse_loss(pred_omega: torch.Tensor, target_omega: torch.Tensor)
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
|
||||
def clip_vector_norm_batch(x: torch.Tensor, max_norm: float) -> torch.Tensor:
|
||||
if max_norm <= 0:
|
||||
return x
|
||||
norms = torch.linalg.norm(x, dim=-1, keepdim=True).clamp(min=1e-9)
|
||||
scale = torch.clamp(max_norm / norms, max=1.0)
|
||||
return x * scale
|
||||
|
||||
|
||||
def skew(w: torch.Tensor) -> torch.Tensor:
|
||||
z = torch.zeros((), device=w.device, dtype=w.dtype)
|
||||
return torch.stack(
|
||||
@@ -365,6 +388,7 @@ def evaluate_mean_rmsd_100(
|
||||
device: torch.device,
|
||||
num_runs: int = 100,
|
||||
rollout_steps: int = 100,
|
||||
omega_max_norm: float = 0.0,
|
||||
) -> float:
|
||||
"""Run 100 one-shot predictions to time=1 and average final-structure RMSD."""
|
||||
del rollout_steps
|
||||
@@ -390,7 +414,7 @@ def evaluate_mean_rmsd_100(
|
||||
with torch.no_grad():
|
||||
v = model(batch, tb)
|
||||
vc = v.center.detach().cpu()
|
||||
vw = v.rot_omega.detach().cpu()
|
||||
vw = clip_vector_norm_batch(v.rot_omega.detach().cpu(), omega_max_norm)
|
||||
vt = v.torsion.detach().cpu()
|
||||
center_t = center_t + vc
|
||||
torsion_t = torsion_t + vt
|
||||
@@ -407,6 +431,68 @@ def evaluate_mean_rmsd_100(
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model
|
||||
# ---------------------------------------------------------------------------
|
||||
class TorsionSubgraphSpec(nn.Module):
|
||||
"""Movable-side induced subgraph for one torsion (full-graph node ids + local edge_index)."""
|
||||
|
||||
def __init__(self, nodes_full: torch.Tensor, edge_index_local: torch.Tensor):
|
||||
super().__init__()
|
||||
self.register_buffer("nodes_full", nodes_full, persistent=True)
|
||||
self.register_buffer("edge_index_local", edge_index_local, persistent=True)
|
||||
|
||||
|
||||
def build_torsion_subgraph_specs(sample: Data) -> nn.ModuleList:
|
||||
"""Induced subgraph on atoms where movable_mask[k] is True for each torsion k."""
|
||||
n = int(sample.num_nodes)
|
||||
ei = sample.edge_index
|
||||
if ei.device.type != "cpu":
|
||||
ei = ei.cpu()
|
||||
mv = sample.movable_mask
|
||||
if mv.device.type != "cpu":
|
||||
mv = mv.cpu()
|
||||
t_count = int(sample.bond_endpoints.shape[0])
|
||||
modules: list[TorsionSubgraphSpec] = []
|
||||
for k in range(t_count):
|
||||
mask_row = mv[k].bool()
|
||||
nodes_full = torch.where(mask_row)[0].long()
|
||||
if nodes_full.numel() == 0:
|
||||
raise RuntimeError(f"torsion {k}: empty movable side; cannot build subgraph")
|
||||
node_set = set(int(x) for x in nodes_full.tolist())
|
||||
old_to_new = {old: i for i, old in enumerate(sorted(node_set))}
|
||||
es: list[int] = []
|
||||
ed: list[int] = []
|
||||
for e in range(ei.shape[1]):
|
||||
a, b = int(ei[0, e]), int(ei[1, e])
|
||||
if a in node_set and b in node_set:
|
||||
es.append(old_to_new[a])
|
||||
ed.append(old_to_new[b])
|
||||
if len(es) > 0:
|
||||
edge_index_local = torch.tensor([es, ed], dtype=torch.long)
|
||||
else:
|
||||
nk = nodes_full.numel()
|
||||
edge_index_local = torch.stack(
|
||||
[torch.arange(nk, dtype=torch.long), torch.arange(nk, dtype=torch.long)], dim=0
|
||||
)
|
||||
modules.append(TorsionSubgraphSpec(nodes_full, edge_index_local))
|
||||
return nn.ModuleList(modules)
|
||||
|
||||
|
||||
class AttentionGraphReadout(nn.Module):
|
||||
"""Learned softmax weights over nodes (distinct inductive bias vs mean pooling)."""
|
||||
|
||||
def __init__(self, node_dim: int, gate_hidden: int = 64):
|
||||
super().__init__()
|
||||
self.gate = nn.Sequential(
|
||||
nn.Linear(node_dim, gate_hidden),
|
||||
nn.SiLU(),
|
||||
nn.Linear(gate_hidden, 1),
|
||||
)
|
||||
|
||||
def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
|
||||
scores = self.gate(h).squeeze(-1)
|
||||
alpha = softmax(scores, batch)
|
||||
return global_add_pool(h * alpha.unsqueeze(-1), batch)
|
||||
|
||||
|
||||
class RFMModel(nn.Module):
|
||||
"""Model for velocity prediction. Supports GCN and fixed-order MLP encoders."""
|
||||
|
||||
@@ -420,6 +506,9 @@ class RFMModel(nn.Module):
|
||||
gcn_residual: bool = False,
|
||||
channel_layernorm: bool = False,
|
||||
head_mlp_layers: int = 1,
|
||||
graph_readout: str = "mean",
|
||||
torsion_head_mode: str = "global",
|
||||
torsion_subgraphs: nn.ModuleList | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_type = model_type
|
||||
@@ -431,6 +520,8 @@ class RFMModel(nn.Module):
|
||||
self.channel_layernorm = channel_layernorm
|
||||
self.act = nn.SiLU()
|
||||
self.head_mlp_layers = max(1, int(head_mlp_layers))
|
||||
self.graph_readout = graph_readout
|
||||
self.torsion_head_mode = torsion_head_mode
|
||||
|
||||
if model_type == "gcn":
|
||||
self.convs = nn.ModuleList()
|
||||
@@ -441,7 +532,14 @@ class RFMModel(nn.Module):
|
||||
for _ in range(num_layers - 1):
|
||||
self.gcn_skip.append(nn.Identity())
|
||||
trunk_dim = hidden
|
||||
if graph_readout == "mean":
|
||||
self._pool = None
|
||||
elif graph_readout == "attention":
|
||||
self._pool = AttentionGraphReadout(trunk_dim, gate_hidden=max(32, min(128, hidden)))
|
||||
else:
|
||||
raise ValueError(f"Unsupported graph_readout: {graph_readout}")
|
||||
elif model_type == "mlp":
|
||||
self._pool = None
|
||||
mlp_layers = []
|
||||
in_dim = num_nodes * 3
|
||||
for _ in range(num_layers):
|
||||
@@ -471,7 +569,69 @@ class RFMModel(nn.Module):
|
||||
|
||||
self.trans_head = make_head(3)
|
||||
self.rot_head = make_head(3)
|
||||
self.torsion_head = make_head(num_torsions)
|
||||
if torsion_head_mode == "global":
|
||||
self.torsion_head = make_head(num_torsions)
|
||||
self.torsion_pair_mlp = None
|
||||
self.torsion_pair_in_norm = None
|
||||
self.torsion_subgraphs = None
|
||||
elif torsion_head_mode == "bond_pair":
|
||||
if model_type != "gcn":
|
||||
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
|
||||
if torsion_subgraphs is None or len(torsion_subgraphs) != num_torsions:
|
||||
raise ValueError("bond_pair requires torsion_subgraphs built from sample.movable_mask")
|
||||
# Global pooled context + mean pool over a GCN run on the movable-side induced subgraph only.
|
||||
pair_in = d + hidden
|
||||
ph = max(64, min(256, hidden * 2))
|
||||
self.torsion_head = None
|
||||
self.torsion_subgraphs = torsion_subgraphs
|
||||
self.torsion_pair_in_norm = nn.LayerNorm(pair_in)
|
||||
self.torsion_pair_mlp = nn.Sequential(
|
||||
nn.Linear(pair_in, ph),
|
||||
nn.SiLU(),
|
||||
nn.Linear(ph, ph),
|
||||
nn.SiLU(),
|
||||
nn.Linear(ph, 1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
|
||||
|
||||
def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor) -> torch.Tensor:
|
||||
"""Per-torsion: same GCN weights on movable-side subgraph coords, then pooled + MLP."""
|
||||
if self.num_torsions == 0:
|
||||
return pooled.new_zeros((pooled.shape[0], 0))
|
||||
assert self.torsion_subgraphs is not None
|
||||
assert self.torsion_pair_mlp is not None and self.torsion_pair_in_norm is not None
|
||||
if not hasattr(x, "ptr") or x.ptr is None:
|
||||
raise RuntimeError("bond_pair torsion head requires Batch.ptr (use Batch.from_data_list)")
|
||||
|
||||
B = int(pooled.shape[0])
|
||||
ptr = x.ptr.to(device=pooled.device)
|
||||
device = pooled.device
|
||||
dtype = pooled.dtype
|
||||
outs: list[torch.Tensor] = []
|
||||
for k in range(self.num_torsions):
|
||||
spec = self.torsion_subgraphs[k]
|
||||
nodes_f = spec.nodes_full.to(device=device, dtype=torch.long)
|
||||
ei_loc = spec.edge_index_local.to(device=device)
|
||||
graphs: list[Data] = []
|
||||
for b in range(B):
|
||||
base = int(ptr[b].item())
|
||||
coords_sub = x.coords[base + nodes_f].to(dtype=dtype)
|
||||
graphs.append(Data(x=coords_sub, edge_index=ei_loc))
|
||||
batch_sub = Batch.from_data_list(graphs).to(device)
|
||||
h = batch_sub.x
|
||||
edge_index = batch_sub.edge_index
|
||||
batch_vec = batch_sub.batch
|
||||
for idx, conv in enumerate(self.convs):
|
||||
h_new = self.act(conv(h, edge_index))
|
||||
if self.gcn_residual and h_new.shape == h.shape:
|
||||
h = h_new + self.gcn_skip[idx](h)
|
||||
else:
|
||||
h = h_new
|
||||
sub_trunk = global_mean_pool(h, batch_vec)
|
||||
feat = self.torsion_pair_in_norm(torch.cat([pooled, sub_trunk], dim=-1))
|
||||
outs.append(self.torsion_pair_mlp(feat).squeeze(-1))
|
||||
return torch.stack(outs, dim=1)
|
||||
|
||||
def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
|
||||
if self.model_type == "gcn":
|
||||
@@ -484,7 +644,10 @@ class RFMModel(nn.Module):
|
||||
h = h_new + self.gcn_skip[idx](h)
|
||||
else:
|
||||
h = h_new
|
||||
trunk = global_mean_pool(h, batch)
|
||||
if self._pool is None:
|
||||
trunk = global_mean_pool(h, batch)
|
||||
else:
|
||||
trunk = self._pool(h, batch)
|
||||
else:
|
||||
bsz = time.shape[0]
|
||||
coords = x.coords.reshape(bsz, self.num_nodes * 3)
|
||||
@@ -496,7 +659,12 @@ class RFMModel(nn.Module):
|
||||
pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
|
||||
trans = self.trans_head(pooled)
|
||||
rot_omega = self.rot_head(pooled)
|
||||
torsion = self.torsion_head(pooled)
|
||||
if self.torsion_head_mode == "bond_pair":
|
||||
assert self.model_type == "gcn"
|
||||
torsion = self._torsion_from_subgraph_specs(x, pooled)
|
||||
else:
|
||||
assert self.torsion_head is not None
|
||||
torsion = self.torsion_head(pooled)
|
||||
return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
|
||||
|
||||
|
||||
@@ -673,6 +841,18 @@ def main() -> int:
|
||||
parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
|
||||
parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
|
||||
parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
|
||||
parser.add_argument(
|
||||
"--tail-risk-weight",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Extra weight on upper-quantile per-sample loss (0 disables).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tail-risk-quantile",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Upper quantile threshold for tail-risk penalty.",
|
||||
)
|
||||
parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
|
||||
@@ -757,9 +937,27 @@ def main() -> int:
|
||||
"--rotation-loss",
|
||||
type=str,
|
||||
default="mse",
|
||||
choices=["mse", "geodesic"],
|
||||
choices=["mse", "geodesic", "hybrid"],
|
||||
help="Rotation-channel loss type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rotation-hybrid-alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="For rotation-loss=hybrid: weight on geodesic term (0..1).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rotation-weight-start",
|
||||
type=float,
|
||||
default=-1.0,
|
||||
help="If >=0, linearly anneal effective rotation weight from start to --weight-omega.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rotation-weight-warmup-epochs",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Warmup epochs for rotation weight schedule.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gcn-residual",
|
||||
action="store_true",
|
||||
@@ -776,6 +974,24 @@ def main() -> int:
|
||||
default=1,
|
||||
help="Per-channel head depth (>=1).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--graph-readout",
|
||||
type=str,
|
||||
default="mean",
|
||||
choices=["mean", "attention"],
|
||||
help="GCN graph readout: mean pool vs learned attention pool (ignored for model-type=mlp).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torsion-head",
|
||||
type=str,
|
||||
default="global",
|
||||
choices=["global", "bond_pair"],
|
||||
help=(
|
||||
"global: torsion vector from pooled trunk+time. "
|
||||
"bond_pair (GCN only): per torsion k, run the same GCN trunk on the movable-side induced subgraph "
|
||||
"(mask selects nodes/edges only; mask values are not fed as features), mean-pool, then MLP with global pooled context."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early-stop-patience",
|
||||
type=int,
|
||||
@@ -800,6 +1016,12 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Skip early-stop checks for first N epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--omega-max-norm",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="If >0, clip predicted rotation omega vector norm for stability.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@@ -819,6 +1041,13 @@ def main() -> int:
|
||||
return 1
|
||||
|
||||
sample = prepare_sample(sample)
|
||||
torsion_head_mode = args.torsion_head
|
||||
if torsion_head_mode == "bond_pair" and args.model_type != "gcn":
|
||||
print(
|
||||
"Warning: --torsion-head bond_pair requires --model-type gcn; using global torsion head.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
torsion_head_mode = "global"
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))
|
||||
|
||||
@@ -834,6 +1063,10 @@ def main() -> int:
|
||||
torsion=sample.gt_torsion,
|
||||
)
|
||||
|
||||
torsion_subgraphs: nn.ModuleList | None = None
|
||||
if torsion_head_mode == "bond_pair":
|
||||
torsion_subgraphs = build_torsion_subgraph_specs(sample)
|
||||
|
||||
model = RFMModel(
|
||||
model_type=args.model_type,
|
||||
num_nodes=int(sample.num_nodes),
|
||||
@@ -843,6 +1076,9 @@ def main() -> int:
|
||||
gcn_residual=args.gcn_residual,
|
||||
channel_layernorm=args.channel_layernorm,
|
||||
head_mlp_layers=args.head_mlp_layers,
|
||||
graph_readout=args.graph_readout,
|
||||
torsion_head_mode=torsion_head_mode,
|
||||
torsion_subgraphs=torsion_subgraphs,
|
||||
).to(device)
|
||||
if args.compile and hasattr(torch, "compile"):
|
||||
model = torch.compile(model) # type: ignore[assignment]
|
||||
@@ -859,16 +1095,27 @@ def main() -> int:
|
||||
f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
|
||||
f"grad_clip={args.grad_clip}"
|
||||
)
|
||||
print(f"tail_risk: weight={args.tail_risk_weight} quantile={args.tail_risk_quantile}")
|
||||
print("loss domain=%s torsion_wrapped_loss=always_on" % args.loss_domain)
|
||||
print(
|
||||
f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
|
||||
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers}"
|
||||
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers} "
|
||||
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
|
||||
)
|
||||
if args.rotation_loss == "hybrid":
|
||||
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
|
||||
if args.rotation_weight_start >= 0:
|
||||
print(
|
||||
"rotation_weight_schedule: "
|
||||
f"start={args.rotation_weight_start} end={args.weight_omega} "
|
||||
f"warmup_epochs={args.rotation_weight_warmup_epochs}"
|
||||
)
|
||||
print(
|
||||
"early_stop: "
|
||||
f"patience={args.early_stop_patience} min_delta={args.early_stop_min_delta} "
|
||||
f"check_every={args.early_stop_check_every} warmup={args.early_stop_warmup}"
|
||||
)
|
||||
print(f"omega_max_norm={args.omega_max_norm}")
|
||||
|
||||
nw = args.num_workers
|
||||
if nw < 0:
|
||||
@@ -941,14 +1188,14 @@ def main() -> int:
|
||||
time_remain = (1.0 - time_b).clamp(min=1e-6)
|
||||
if args.loss_domain == "displacement":
|
||||
pred_center = pred.center * time_remain
|
||||
pred_omega = pred.rot_omega * time_remain
|
||||
pred_omega = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm) * time_remain
|
||||
pred_torsion = pred.torsion * time_remain
|
||||
tgt_center = tc * time_remain
|
||||
tgt_omega = tw * time_remain
|
||||
tgt_torsion = tt * time_remain
|
||||
else:
|
||||
pred_center = pred.center
|
||||
pred_omega = pred.rot_omega
|
||||
pred_omega = clip_vector_norm_batch(pred.rot_omega, args.omega_max_norm)
|
||||
pred_torsion = pred.torsion
|
||||
tgt_center = tc
|
||||
tgt_omega = tw
|
||||
@@ -958,21 +1205,49 @@ def main() -> int:
|
||||
scale_c = (tgt_center * tgt_center).mean().clamp(min=eps)
|
||||
scale_w = (tgt_omega * tgt_omega).mean().clamp(min=eps)
|
||||
scale_t = (tgt_torsion * tgt_torsion).mean().clamp(min=eps)
|
||||
per_center = ((pred_center - tgt_center) ** 2).mean(dim=-1) / scale_c
|
||||
if args.rotation_loss == "geodesic":
|
||||
rot_loss = so3_geodesic_mse_loss(pred_omega, tgt_omega)
|
||||
per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
|
||||
elif args.rotation_loss == "hybrid":
|
||||
alpha = float(max(0.0, min(1.0, args.rotation_hybrid_alpha)))
|
||||
rot_loss = (
|
||||
alpha * so3_geodesic_mse_loss(pred_omega, tgt_omega)
|
||||
+ (1.0 - alpha) * F.mse_loss(pred_omega, tgt_omega)
|
||||
)
|
||||
per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
|
||||
else:
|
||||
rot_loss = F.mse_loss(pred_omega, tgt_omega)
|
||||
mse = (
|
||||
args.weight_center * (F.mse_loss(pred_center, tgt_center) / scale_c)
|
||||
+ args.weight_omega * (rot_loss / scale_w)
|
||||
+ args.weight_torsion
|
||||
* (
|
||||
torsion_wrapped_mse_loss(pred_torsion, tgt_torsion)
|
||||
/ scale_t
|
||||
per_rot = ((pred_omega - tgt_omega) ** 2).mean(dim=-1) / scale_w
|
||||
per_torsion = (wrap_angle(pred_torsion - tgt_torsion) ** 2).mean(dim=-1) / scale_t
|
||||
current_weight_omega = (
|
||||
args.weight_omega
|
||||
if args.rotation_weight_start < 0
|
||||
else (
|
||||
args.rotation_weight_start
|
||||
+ (
|
||||
max(0.0, min(1.0, epoch / max(1, args.rotation_weight_warmup_epochs)))
|
||||
* (args.weight_omega - args.rotation_weight_start)
|
||||
)
|
||||
)
|
||||
)
|
||||
mse = (
|
||||
args.weight_center * per_center.mean()
|
||||
+ current_weight_omega * (rot_loss / scale_w)
|
||||
+ args.weight_torsion * per_torsion.mean()
|
||||
)
|
||||
if args.tail_risk_weight > 0:
|
||||
per_sample_total = (
|
||||
args.weight_center * per_center
|
||||
+ current_weight_omega * per_rot
|
||||
+ args.weight_torsion * per_torsion
|
||||
)
|
||||
mse = mse + args.tail_risk_weight * upper_quantile_mean(
|
||||
per_sample_total, args.tail_risk_quantile
|
||||
)
|
||||
if not torch.isfinite(mse):
|
||||
mse = torch.tensor(1e3, device=device, dtype=torch.float32)
|
||||
# Keep graph connectivity while preventing non-finite backward.
|
||||
mse = torch.nan_to_num(mse, nan=1e3, posinf=1e3, neginf=1e3)
|
||||
(mse / accum).backward()
|
||||
sum_mse = sum_mse + mse.detach()
|
||||
|
||||
@@ -1057,6 +1332,7 @@ def main() -> int:
|
||||
device=device,
|
||||
num_runs=args.eval_runs,
|
||||
rollout_steps=args.rollout_steps,
|
||||
omega_max_norm=args.omega_max_norm,
|
||||
)
|
||||
print(f"final_test_mean_rmsd_{args.eval_runs}: {final_mean_rmsd_100:.6f}")
|
||||
|
||||
@@ -1132,7 +1408,7 @@ def main() -> int:
|
||||
with torch.no_grad():
|
||||
v = model(batch, tb)
|
||||
vc = v.center.squeeze(0)
|
||||
vw = v.rot_omega.squeeze(0)
|
||||
vw = clip_vector_norm_batch(v.rot_omega, args.omega_max_norm).squeeze(0)
|
||||
vt = v.torsion.squeeze(0)
|
||||
dt = 1.0 / max(1, n_frames - 1)
|
||||
state = RigidTorsionState(
|
||||
|
||||
Reference in New Issue
Block a user