bond_pair: fuse full-graph node h into subgraph GCN input.

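Subgraph node features are now concat(coords, trunk h), where h is the full-graph node embedding from the same trunk forward pass, so the fragment encoder sees the global geometry context instead of raw coordinates only; accordingly, the first sub_conv becomes GCNConv(3 + hidden, hidden).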

Made-with: Cursor
This commit is contained in:
demian3b
2026-04-17 00:46:14 +09:00
parent dac3a163ae
commit 80a3e3a014
2 changed files with 17 additions and 12 deletions

View File

@@ -1,10 +1,10 @@
{
"mean_rmsd_100": 2.60611756503582,
"mean_rmsd_100": 2.620511382818222,
"num_runs": 100,
"timestamp_utc": "2026-04-16T15:34:50.254678+00:00",
"timestamp_utc": "2026-04-16T15:40:21.752941+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1",
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
"best_train_mse": 3.184145927429199,
"best_train_mse": 3.1972899436950684,
"model_source": "best_train_checkpoint",
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
"stop_reason": "max_epochs",

View File

@@ -591,9 +591,11 @@ class RFMModel(nn.Module):
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
if torsion_subgraphs is None or len(torsion_subgraphs) != num_torsions:
raise ValueError("bond_pair requires torsion_subgraphs built from sample.movable_mask")
# Dedicated subgraph encoder (do not share trunk weights with full-graph convs).
# Dedicated subgraph encoder (weights not shared with full-graph convs).
# First-layer input = coords ⊕ full-graph node embedding (from the same trunk forward pass), so the
# trunk's geometry context is not thrown away when only the movable fragment is re-encoded.
self.sub_convs = nn.ModuleList()
self.sub_convs.append(GCNConv(3, hidden))
self.sub_convs.append(GCNConv(3 + hidden, hidden))
for _ in range(num_layers - 1):
self.sub_convs.append(GCNConv(hidden, hidden))
# Global pooled context + mean pool over subgraph GCN.
@@ -620,8 +622,8 @@ class RFMModel(nn.Module):
else:
raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor) -> torch.Tensor:
"""Batched subgraph GCN (separate sub_convs), one forward for all (batch, torsion) pairs."""
def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor, h_full: torch.Tensor) -> torch.Tensor:
"""Batched subgraph GCN; node input is [coords | h_full] gathered onto each torsion subgraph."""
if self.num_torsions == 0:
return pooled.new_zeros((pooled.shape[0], 0))
assert self.torsion_subgraphs is not None
@@ -642,8 +644,11 @@ class RFMModel(nn.Module):
ei_loc = spec.edge_index_local.to(device=device)
for b in range(B):
base = int(ptr[b].item())
coords_sub = x.coords[base + nodes_f].to(dtype=dtype)
graphs.append(Data(x=coords_sub, edge_index=ei_loc))
idx = base + nodes_f
coords_sub = x.coords[idx].to(dtype=dtype)
h_sub = h_full[idx].to(dtype=dtype)
x_sub = torch.cat([coords_sub, h_sub], dim=-1)
graphs.append(Data(x=x_sub, edge_index=ei_loc))
batch_sub = Batch.from_data_list(graphs).to(device)
h = batch_sub.x
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
@@ -696,7 +701,7 @@ class RFMModel(nn.Module):
rot_omega = self.rot_head(pooled)
if self.torsion_head_mode == "bond_pair":
assert self.model_type == "gcn"
torsion = self._torsion_from_subgraph_specs(x, pooled)
torsion = self._torsion_from_subgraph_specs(x, pooled, h)
else:
assert self.torsion_head is not None
torsion = self.torsion_head(pooled)
@@ -1023,8 +1028,8 @@ def main() -> int:
choices=["global", "bond_pair"],
help=(
"global: torsion vector from pooled trunk+time. "
"bond_pair (GCN only): induced subgraph = movable atoms bond (i,j); dedicated sub_convs (not weight-tied "
"to full-graph trunk); batched B×T subgraph forward; then MLP on [global pooled || subgraph pooled]."
"bond_pair (GCN only): induced subgraph = movable bond (i,j); subgraph node input = concat(coords, "
"full-graph trunk node h); dedicated sub_convs; batched B×T forward; MLP on [global pooled || subgraph pool]."
),
)
parser.add_argument(