Add bond-conditioned torsion head and log first eval.

Per-torsion predictions use bond endpoints and movable masks with pooled context; document smoke runs and persist latest eval artifacts.

Made-with: Cursor
This commit is contained in:
demian3b
2026-04-16 23:52:25 +09:00
parent fe8303ed49
commit 2a5f867839
3 changed files with 110 additions and 7 deletions

View File

@@ -99,3 +99,4 @@ This repository is intentionally pinned to CUDA 12.6 PyTorch wheels and matching
- 2026-04-16: Strategy S4 micro-tuning #3 softened tail coverage (`tail-risk-quantile=0.9`, `tail-risk-weight=0.2`, `lr=6.8e-4`) and improved to `mean_rmsd_100=2.440570`, but still below best `2.388103`.
- 2026-04-16: Strategy S4 micro-tuning #4 increased tail penalty (`tail-risk-weight=0.25`, quantile `0.9`) and regressed sharply to `mean_rmsd_100=2.601258`; indicates over-penalization risk.
- 2026-04-16: Strategy S4 micro-tuning #5 changed seed (`seed=2`, `tail-risk-weight=0.2`, quantile `0.9`) and encountered prolonged fallback-to-1000 behavior with `mean_rmsd_100=2.709563`; S4 hit 5-run cap with no best update.
- 2026-04-16: Structural torsion head (`--torsion-head bond_pair`, GCN only): translation/rotation still use mean-pooled trunk+time; each torsion channel is predicted from that same global context plus bond endpoint node embeddings, `movable_mask`-weighted fragment pool, bond vector, and endpoint mask bits. Added `LayerNorm` on the concatenated bond-pair feature vector for scale. Training still hit long `train_mse=1000` plateaus; best eval in quick trials was `mean_rmsd_100=2.514426` before that norm and `2.537496` after—both worse than best `2.388103`, so the hypothesis needs more stabilization or a different pairing signal rather than more depth alone.

View File

@@ -1,10 +1,10 @@
{
"mean_rmsd_100": 2.7095625364780425,
"mean_rmsd_100": 2.537496319413185,
"num_runs": 100,
"timestamp_utc": "2026-04-16T14:39:33.444352+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.2 --tail-risk-quantile 0.9 --seed 2",
"timestamp_utc": "2026-04-16T14:52:02.649041+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --seed 1",
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
"best_train_mse": 9.438053131103516,
"best_train_mse": 6.513882160186768,
"model_source": "best_train_checkpoint",
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
"stop_reason": "max_epochs",

108
train.py
View File

@@ -462,6 +462,7 @@ class RFMModel(nn.Module):
channel_layernorm: bool = False,
head_mlp_layers: int = 1,
graph_readout: str = "mean",
torsion_head_mode: str = "global",
):
super().__init__()
self.model_type = model_type
@@ -474,6 +475,7 @@ class RFMModel(nn.Module):
self.act = nn.SiLU()
self.head_mlp_layers = max(1, int(head_mlp_layers))
self.graph_readout = graph_readout
self.torsion_head_mode = torsion_head_mode
if model_type == "gcn":
self.convs = nn.ModuleList()
@@ -521,7 +523,84 @@ class RFMModel(nn.Module):
self.trans_head = make_head(3)
self.rot_head = make_head(3)
self.torsion_head = make_head(num_torsions)
if torsion_head_mode == "global":
self.torsion_head = make_head(num_torsions)
self.torsion_pair_mlp = None
self.torsion_pair_in_norm = None
elif torsion_head_mode == "bond_pair":
if model_type != "gcn":
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
# pooled || h_i || h_j || movable-side pool || rel_ij || m_i || m_j
pair_in = d + 3 * hidden + 3 + 2
ph = max(64, min(256, hidden * 2))
self.torsion_head = None
self.torsion_pair_in_norm = nn.LayerNorm(pair_in)
self.torsion_pair_mlp = nn.Sequential(
nn.Linear(pair_in, ph),
nn.SiLU(),
nn.Linear(ph, ph),
nn.SiLU(),
nn.Linear(ph, 1),
)
else:
raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
def _torsion_from_bond_pair(
self,
h: torch.Tensor,
x: Data,
pooled: torch.Tensor,
time: torch.Tensor,
) -> torch.Tensor:
"""Per-torsion velocity from bond endpoints + movable mask (GCN batched graphs)."""
del time
if self.num_torsions == 0:
return pooled.new_zeros((pooled.shape[0], 0))
assert self.torsion_pair_mlp is not None and self.torsion_pair_in_norm is not None
if not hasattr(x, "bond_endpoints") or not hasattr(x, "movable_mask"):
raise RuntimeError("bond_pair torsion head requires bond_endpoints and movable_mask on graph batch")
if not hasattr(x, "ptr") or x.ptr is None:
raise RuntimeError("bond_pair torsion head requires Batch.ptr (use Batch.from_data_list)")
B = int(pooled.shape[0])
T = self.num_torsions
device = h.device
dtype = h.dtype
be = x.bond_endpoints.to(device=device, dtype=torch.long)
if be.shape[0] != B * T:
raise RuntimeError(
f"bond_endpoints rows {be.shape[0]} != B*T {B * T}; "
"graphs must share the same torsion count"
)
ptr = x.ptr.to(device=device)
N = int(ptr[1].item() - ptr[0].item())
b_idx = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T)
base = ptr[b_idx]
i_loc = be[:, 0]
j_loc = be[:, 1]
gid_i = base + i_loc
gid_j = base + j_loc
hi = h[gid_i]
hj = h[gid_j]
rel = x.coords[gid_i] - x.coords[gid_j]
mv = x.movable_mask.to(device=device, dtype=dtype)
if mv.shape != (B * T, N):
raise RuntimeError(f"movable_mask shape {tuple(mv.shape)} expected ({B * T}, {N})")
n_idx = torch.arange(N, device=device, dtype=torch.long).unsqueeze(0).expand(B * T, -1)
gather_idx = base.unsqueeze(1) + n_idx
h_side = h[gather_idx]
mv_e = mv.unsqueeze(-1)
denom = mv_e.sum(dim=1).clamp(min=1e-6)
side_pool = (h_side * mv_e).sum(dim=1) / denom
ar = torch.arange(B * T, device=device, dtype=torch.long)
mi = mv[ar, i_loc].unsqueeze(-1)
mj = mv[ar, j_loc].unsqueeze(-1)
pooled_rep = pooled[b_idx]
feat = self.torsion_pair_in_norm(torch.cat([pooled_rep, hi, hj, side_pool, rel, mi, mj], dim=-1))
out = self.torsion_pair_mlp(feat).squeeze(-1)
return out.view(B, T)
def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
if self.model_type == "gcn":
@@ -549,7 +628,12 @@ class RFMModel(nn.Module):
pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
trans = self.trans_head(pooled)
rot_omega = self.rot_head(pooled)
torsion = self.torsion_head(pooled)
if self.torsion_head_mode == "bond_pair":
assert self.model_type == "gcn"
torsion = self._torsion_from_bond_pair(h, x, pooled, time)
else:
assert self.torsion_head is not None
torsion = self.torsion_head(pooled)
return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
@@ -866,6 +950,16 @@ def main() -> int:
choices=["mean", "attention"],
help="GCN graph readout: mean pool vs learned attention pool (ignored for model-type=mlp).",
)
parser.add_argument(
"--torsion-head",
type=str,
default="global",
choices=["global", "bond_pair"],
help=(
"global: torsion vector from pooled trunk+time. "
"bond_pair (GCN only): per-torsion MLP using bond_endpoints, movable_mask, and node embeddings."
),
)
parser.add_argument(
"--early-stop-patience",
type=int,
@@ -915,6 +1009,13 @@ def main() -> int:
return 1
sample = prepare_sample(sample)
torsion_head_mode = args.torsion_head
if torsion_head_mode == "bond_pair" and args.model_type != "gcn":
print(
"Warning: --torsion-head bond_pair requires --model-type gcn; using global torsion head.",
file=sys.stderr,
)
torsion_head_mode = "global"
os.makedirs(args.out_dir, exist_ok=True)
save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))
@@ -940,6 +1041,7 @@ def main() -> int:
channel_layernorm=args.channel_layernorm,
head_mlp_layers=args.head_mlp_layers,
graph_readout=args.graph_readout,
torsion_head_mode=torsion_head_mode,
).to(device)
if args.compile and hasattr(torch, "compile"):
model = torch.compile(model) # type: ignore[assignment]
@@ -961,7 +1063,7 @@ def main() -> int:
print(
f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers} "
f"graph_readout={args.graph_readout}"
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
)
if args.rotation_loss == "hybrid":
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")