bond_pair: fuse full-graph node embeddings (h) into the subgraph GCN input.
Subgraph node features are now concat(coords, trunk h), so the fragment encoder sees the global geometry context computed in the same forward pass; accordingly, the first sub_conv changes from GCNConv(3, hidden) to GCNConv(3 + hidden, hidden). Made-with: Cursor
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
{
|
||||
"mean_rmsd_100": 2.60611756503582,
|
||||
"mean_rmsd_100": 2.620511382818222,
|
||||
"num_runs": 100,
|
||||
"timestamp_utc": "2026-04-16T15:34:50.254678+00:00",
|
||||
"timestamp_utc": "2026-04-16T15:40:21.752941+00:00",
|
||||
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --epochs 320 --batch-size 24 --model-type gcn --hidden 512 --gcn-layers 8 --eval-runs 100 --loss-domain displacement --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.7 --rotation-loss geodesic --gcn-residual --lr 5.5e-4 --omega-max-norm 5.0 --rotation-weight-start 1.0 --rotation-weight-warmup-epochs 120 --tail-risk-weight 0.0 --torsion-head bond_pair --subgraph-lr-scale 0.25 --subgraph-torsion-clip 6 --seed 1",
|
||||
"notes": "Final test metric from 100 random-initialized rollouts to time=1.",
|
||||
"best_train_mse": 3.184145927429199,
|
||||
"best_train_mse": 3.1972899436950684,
|
||||
"model_source": "best_train_checkpoint",
|
||||
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/latest_eval_best_model.pt",
|
||||
"stop_reason": "max_epochs",
|
||||
|
||||
23
train.py
23
train.py
@@ -591,9 +591,11 @@ class RFMModel(nn.Module):
|
||||
raise ValueError("torsion_head_mode=bond_pair requires model_type=gcn")
|
||||
if torsion_subgraphs is None or len(torsion_subgraphs) != num_torsions:
|
||||
raise ValueError("bond_pair requires torsion_subgraphs built from sample.movable_mask")
|
||||
# Dedicated subgraph encoder (do not share trunk weights with full-graph convs).
|
||||
# Dedicated subgraph encoder (weights not shared with full-graph convs).
|
||||
# First-layer input = coords ⊕ full-graph node embedding (same trunk forward); avoids
|
||||
# throwing away trunk geometry context when re-encoding the movable fragment only.
|
||||
self.sub_convs = nn.ModuleList()
|
||||
self.sub_convs.append(GCNConv(3, hidden))
|
||||
self.sub_convs.append(GCNConv(3 + hidden, hidden))
|
||||
for _ in range(num_layers - 1):
|
||||
self.sub_convs.append(GCNConv(hidden, hidden))
|
||||
# Global pooled context + mean pool over subgraph GCN.
|
||||
@@ -620,8 +622,8 @@ class RFMModel(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Unsupported torsion_head_mode: {torsion_head_mode}")
|
||||
|
||||
def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor) -> torch.Tensor:
|
||||
"""Batched subgraph GCN (separate sub_convs), one forward for all (batch, torsion) pairs."""
|
||||
def _torsion_from_subgraph_specs(self, x: Data, pooled: torch.Tensor, h_full: torch.Tensor) -> torch.Tensor:
|
||||
"""Batched subgraph GCN; node input is [coords | h_full] gathered onto each torsion subgraph."""
|
||||
if self.num_torsions == 0:
|
||||
return pooled.new_zeros((pooled.shape[0], 0))
|
||||
assert self.torsion_subgraphs is not None
|
||||
@@ -642,8 +644,11 @@ class RFMModel(nn.Module):
|
||||
ei_loc = spec.edge_index_local.to(device=device)
|
||||
for b in range(B):
|
||||
base = int(ptr[b].item())
|
||||
coords_sub = x.coords[base + nodes_f].to(dtype=dtype)
|
||||
graphs.append(Data(x=coords_sub, edge_index=ei_loc))
|
||||
idx = base + nodes_f
|
||||
coords_sub = x.coords[idx].to(dtype=dtype)
|
||||
h_sub = h_full[idx].to(dtype=dtype)
|
||||
x_sub = torch.cat([coords_sub, h_sub], dim=-1)
|
||||
graphs.append(Data(x=x_sub, edge_index=ei_loc))
|
||||
batch_sub = Batch.from_data_list(graphs).to(device)
|
||||
h = batch_sub.x
|
||||
edge_index, _ = add_self_loops(batch_sub.edge_index, num_nodes=batch_sub.num_nodes)
|
||||
@@ -696,7 +701,7 @@ class RFMModel(nn.Module):
|
||||
rot_omega = self.rot_head(pooled)
|
||||
if self.torsion_head_mode == "bond_pair":
|
||||
assert self.model_type == "gcn"
|
||||
torsion = self._torsion_from_subgraph_specs(x, pooled)
|
||||
torsion = self._torsion_from_subgraph_specs(x, pooled, h)
|
||||
else:
|
||||
assert self.torsion_head is not None
|
||||
torsion = self.torsion_head(pooled)
|
||||
@@ -1023,8 +1028,8 @@ def main() -> int:
|
||||
choices=["global", "bond_pair"],
|
||||
help=(
|
||||
"global: torsion vector from pooled trunk+time. "
|
||||
"bond_pair (GCN only): induced subgraph = movable atoms ∪ bond (i,j); dedicated sub_convs (not weight-tied "
|
||||
"to full-graph trunk); batched B×T subgraph forward; then MLP on [global pooled || subgraph pooled]."
|
||||
"bond_pair (GCN only): induced subgraph = movable ∪ bond (i,j); subgraph node input = concat(coords, "
|
||||
"full-graph trunk node h); dedicated sub_convs; batched B×T forward; MLP on [global pooled || subgraph pool]."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user