Add bond-conditioned torsion head and log first eval.
Per-torsion predictions use bond endpoints and movable masks with pooled context; document smoke runs and persist latest eval artifacts. Made-with: Cursor
This commit is contained in:
@@ -99,3 +99,4 @@ This repository is intentionally pinned to CUDA 12.6 PyTorch wheels and matching
|
||||
- 2026-04-16: Strategy S4 micro-tuning #3 softened tail coverage (`tail-risk-quantile=0.9`, `tail-risk-weight=0.2`, `lr=6.8e-4`) and improved to `mean_rmsd_100=2.440570`, but still below best `2.388103`.
|
||||
- 2026-04-16: Strategy S4 micro-tuning #4 increased tail penalty (`tail-risk-weight=0.25`, quantile `0.9`) and regressed sharply to `mean_rmsd_100=2.601258`; indicates over-penalization risk.
|
||||
- 2026-04-16: Strategy S4 micro-tuning #5 changed seed (`seed=2`, `tail-risk-weight=0.2`, quantile `0.9`) and encountered prolonged fallback-to-1000 behavior with `mean_rmsd_100=2.709563`; S4 hit 5-run cap with no best update.
|
||||
- 2026-04-16: Structural torsion head (`--torsion-head bond_pair`, GCN only): translation/rotation still use mean-pooled trunk+time; each torsion channel is predicted from that same global context plus bond endpoint node embeddings, `movable_mask`-weighted fragment pool, bond vector, and endpoint mask bits. Added `LayerNorm` on the concatenated bond-pair feature vector for scale. Training still hit long `train_mse=1000` plateaus; best eval in quick trials was `mean_rmsd_100=2.514426` before that norm and `2.537496` after—both worse than best `2.388103`, so the hypothesis needs more stabilization or a different pairing signal rather than more depth alone.
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
{
|
||||
"mean_rmsd_100": 2.7095625364780425,
|
||||
"mean_rmsd_100": 2.537496319413185,
|
||||
"num_runs": 100,
|
||||
"timestamp_utc": "2026-04-16T14:39:33.444352+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.2 --tail-risk-quantile 0.9 --seed 2",
|
||||
"timestamp_utc": "2026-04-16T14:52:02.649041+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 6.8e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --seed 1",
|
||||
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
||||
"best_train_mse": 9.438053131103516,
|
||||
"best_train_mse": 6.513882160186768,
|
||||
"model_source": "best_train_checkpoint",
|
||||
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
|
||||
"stop_reason": "max_epochs",
|
||||
|
||||
108
train.py
108
train.py
@@ -462,6 +462,7 @@ class RFMModel(nn.Module):
|
||||
channel_layernorm: bool = False,
|
||||
head_mlp_layers: int = 1,
|
||||
graph_readout: str = "mean",
|
||||
torsion_head_mode: str = "global",
|
||||
):
|
||||
super().__init__()
|
||||
self.model_type = model_type
|
||||
@@ -474,6 +475,7 @@ class RFMModel(nn.Module):
|
||||
self.act = nn.SiLU()
|
||||
self.head_mlp_layers = max(1, int(head_mlp_layers))
|
||||
self.graph_readout = graph_readout
|
||||
self.torsion_head_mode = torsion_head_mode
|
||||
|
||||
if model_type == "gcn":
|
||||
self.convs = nn.ModuleList()
|
||||
@@ -521,7 +523,84 @@ class RFMModel(nn.Module):
|
||||
|
||||
self.trans_head = make_head(3)
|
||||
self.rot_head = make_head(3)
|
||||
self.torsion_head = make_head(num_torsions)
|
||||
if torsion_head_mode == "global":
|
||||
self.torsion_head = make_head(num_torsions)
|
||||
self.torsion_pair_mlp = None
|
||||
self.torsion_pair_in_norm = None
|
||||
elif torsion_head_mode == "bond_pair":
|
||||
if model_type != "gcn":
|
||||
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
|
||||
# pooled || h_i || h_j || movable-side pool || rel_ij || m_i || m_j
|
||||
pair_in = d + 3 * hidden + 3 + 2
|
||||
ph = max(64, min(256, hidden * 2))
|
||||
self.torsion_head = None
|
||||
self.torsion_pair_in_norm = nn.LayerNorm(pair_in)
|
||||
self.torsion_pair_mlp = nn.Sequential(
|
||||
nn.Linear(pair_in, ph),
|
||||
nn.SiLU(),
|
||||
nn.Linear(ph, ph),
|
||||
nn.SiLU(),
|
||||
nn.Linear(ph, 1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
|
||||
|
||||
def _torsion_from_bond_pair(
|
||||
self,
|
||||
h: torch.Tensor,
|
||||
x: Data,
|
||||
pooled: torch.Tensor,
|
||||
time: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Per-torsion velocity from bond endpoints + movable mask (GCN batched graphs)."""
|
||||
del time
|
||||
if self.num_torsions == 0:
|
||||
return pooled.new_zeros((pooled.shape[0], 0))
|
||||
assert self.torsion_pair_mlp is not None and self.torsion_pair_in_norm is not None
|
||||
if not hasattr(x, "bond_endpoints") or not hasattr(x, "movable_mask"):
|
||||
raise RuntimeError("bond_pair torsion head requires bond_endpoints and movable_mask on graph batch")
|
||||
if not hasattr(x, "ptr") or x.ptr is None:
|
||||
raise RuntimeError("bond_pair torsion head requires Batch.ptr (use Batch.from_data_list)")
|
||||
|
||||
B = int(pooled.shape[0])
|
||||
T = self.num_torsions
|
||||
device = h.device
|
||||
dtype = h.dtype
|
||||
be = x.bond_endpoints.to(device=device, dtype=torch.long)
|
||||
if be.shape[0] != B * T:
|
||||
raise RuntimeError(
|
||||
f"bond_endpoints rows {be.shape[0]} != B*T {B * T}; "
|
||||
"graphs must share the same torsion count"
|
||||
)
|
||||
ptr = x.ptr.to(device=device)
|
||||
N = int(ptr[1].item() - ptr[0].item())
|
||||
b_idx = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T)
|
||||
base = ptr[b_idx]
|
||||
i_loc = be[:, 0]
|
||||
j_loc = be[:, 1]
|
||||
gid_i = base + i_loc
|
||||
gid_j = base + j_loc
|
||||
hi = h[gid_i]
|
||||
hj = h[gid_j]
|
||||
rel = x.coords[gid_i] - x.coords[gid_j]
|
||||
|
||||
mv = x.movable_mask.to(device=device, dtype=dtype)
|
||||
if mv.shape != (B * T, N):
|
||||
raise RuntimeError(f"movable_mask shape {tuple(mv.shape)} expected ({B * T}, {N})")
|
||||
n_idx = torch.arange(N, device=device, dtype=torch.long).unsqueeze(0).expand(B * T, -1)
|
||||
gather_idx = base.unsqueeze(1) + n_idx
|
||||
h_side = h[gather_idx]
|
||||
mv_e = mv.unsqueeze(-1)
|
||||
denom = mv_e.sum(dim=1).clamp(min=1e-6)
|
||||
side_pool = (h_side * mv_e).sum(dim=1) / denom
|
||||
|
||||
ar = torch.arange(B * T, device=device, dtype=torch.long)
|
||||
mi = mv[ar, i_loc].unsqueeze(-1)
|
||||
mj = mv[ar, j_loc].unsqueeze(-1)
|
||||
pooled_rep = pooled[b_idx]
|
||||
feat = self.torsion_pair_in_norm(torch.cat([pooled_rep, hi, hj, side_pool, rel, mi, mj], dim=-1))
|
||||
out = self.torsion_pair_mlp(feat).squeeze(-1)
|
||||
return out.view(B, T)
|
||||
|
||||
def forward(self, x: Data, time: torch.Tensor) -> RigidTorsionVelocity:
|
||||
if self.model_type == "gcn":
|
||||
@@ -549,7 +628,12 @@ class RFMModel(nn.Module):
|
||||
pooled = self.shared_norm(torch.cat([trunk, time_feats], dim=-1))
|
||||
trans = self.trans_head(pooled)
|
||||
rot_omega = self.rot_head(pooled)
|
||||
torsion = self.torsion_head(pooled)
|
||||
if self.torsion_head_mode == "bond_pair":
|
||||
assert self.model_type == "gcn"
|
||||
torsion = self._torsion_from_bond_pair(h, x, pooled, time)
|
||||
else:
|
||||
assert self.torsion_head is not None
|
||||
torsion = self.torsion_head(pooled)
|
||||
return RigidTorsionVelocity(center=trans, rot_omega=rot_omega, torsion=torsion)
|
||||
|
||||
|
||||
@@ -866,6 +950,16 @@ def main() -> int:
|
||||
choices=["mean", "attention"],
|
||||
help="GCN graph readout: mean pool vs learned attention pool (ignored for model-type=mlp).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torsion-head",
|
||||
type=str,
|
||||
default="global",
|
||||
choices=["global", "bond_pair"],
|
||||
help=(
|
||||
"global: torsion vector from pooled trunk+time. "
|
||||
"bond_pair (GCN only): per-torsion MLP using bond_endpoints, movable_mask, and node embeddings."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early-stop-patience",
|
||||
type=int,
|
||||
@@ -915,6 +1009,13 @@ def main() -> int:
|
||||
return 1
|
||||
|
||||
sample = prepare_sample(sample)
|
||||
torsion_head_mode = args.torsion_head
|
||||
if torsion_head_mode == "bond_pair" and args.model_type != "gcn":
|
||||
print(
|
||||
"Warning: --torsion-head bond_pair requires --model-type gcn; using global torsion head.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
torsion_head_mode = "global"
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
save_visualization(sample, os.path.join(args.out_dir, "00_gt_coords.sdf"))
|
||||
|
||||
@@ -940,6 +1041,7 @@ def main() -> int:
|
||||
channel_layernorm=args.channel_layernorm,
|
||||
head_mlp_layers=args.head_mlp_layers,
|
||||
graph_readout=args.graph_readout,
|
||||
torsion_head_mode=torsion_head_mode,
|
||||
).to(device)
|
||||
if args.compile and hasattr(torch, "compile"):
|
||||
model = torch.compile(model) # type: ignore[assignment]
|
||||
@@ -961,7 +1063,7 @@ def main() -> int:
|
||||
print(
|
||||
f"rotation_loss={args.rotation_loss} gcn_residual={args.gcn_residual} "
|
||||
f"channel_layernorm={args.channel_layernorm} head_mlp_layers={args.head_mlp_layers} "
|
||||
f"graph_readout={args.graph_readout}"
|
||||
f"graph_readout={args.graph_readout} torsion_head={torsion_head_mode}"
|
||||
)
|
||||
if args.rotation_loss == "hybrid":
|
||||
print(f"rotation_hybrid_alpha={args.rotation_hybrid_alpha}")
|
||||
|
||||
Reference in New Issue
Block a user