Improve RMSD with weighted flow objective.

Increase rotational and torsional emphasis in normalized velocity loss with gradient clipping, yielding a better 100-run final-time RMSD while preserving random-time flow-matching training.

Made-with: Cursor
This commit is contained in:
demian3b
2026-04-16 17:29:43 +09:00
parent e7274cb680
commit 3ddae9d815
4 changed files with 23 additions and 11 deletions

View File

@@ -1,9 +1,9 @@
{ {
"best_mean_rmsd_100": 2.519821240901947, "best_mean_rmsd_100": 2.5055564713478087,
"num_runs": 100, "num_runs": 100,
"timestamp_utc": "2026-04-16T08:21:04.567731+00:00", "timestamp_utc": "2026-04-16T08:29:07.275746+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --epochs 120 --batch-size 96 --num-workers 8 --prefetch-factor 8 --hidden 512 --gcn-layers 8 --accum 2 --time-power 1.0 --eval-runs 100 --out-dir /tmp/ai_rfm_checkpoint_artifacts", "command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --epochs 140 --batch-size 96 --num-workers 8 --prefetch-factor 8 --hidden 512 --gcn-layers 8 --accum 2 --time-power 1.0 --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --eval-runs 100 --out-dir /tmp/ai_rfm_weighted_a",
"notes": "Auto-updated by pre-commit from reports/latest_eval.json.", "notes": "Auto-updated by pre-commit from reports/latest_eval.json.",
"updated_by_commit": "pending", "updated_by_commit": "pending",
"best_train_mse": 2.259780168533325 "best_train_mse": 5.101677894592285
} }

View File

@@ -49,3 +49,4 @@ This repository is intentionally pinned to CUDA 12.6 PyTorch wheels and matching
- 2026-04-16: Added pre-commit artifact refresh: on best update it now stages `BEST_PRACTICE.json`, `artifacts/best_model.pt`, and regenerates 6 trajectory visualizations in `reports/trajectories/`. - 2026-04-16: Added pre-commit artifact refresh: on best update it now stages `BEST_PRACTICE.json`, `artifacts/best_model.pt`, and regenerates 6 trajectory visualizations in `reports/trajectories/`.
- 2026-04-16: Enforced random-time flow-matching rule (no fixed training time), saved best checkpoint to git-tracked artifact path, and improved metric to `mean_rmsd_100=2.519821` with `gcn hidden=512 layers=8 batch=96`. - 2026-04-16: Enforced random-time flow-matching rule (no fixed training time), saved best checkpoint to git-tracked artifact path, and improved metric to `mean_rmsd_100=2.519821` with `gcn hidden=512 layers=8 batch=96`.
- 2026-04-16: Added a general multi-layer diagnosis principle to `GUIDELINES.md` so experiments are judged with quantitative + qualitative + structural evidence, not metric-only optimization. - 2026-04-16: Added a general multi-layer diagnosis principle to `GUIDELINES.md` so experiments are judged with quantitative + qualitative + structural evidence, not metric-only optimization.
- 2026-04-16: Tried weighted objective to counter weak rotation/torsion motion (`w_center=0.8, w_omega=2.0, w_torsion=3.0, grad_clip=0.8`) and improved to `mean_rmsd_100=2.505556`.

View File

@@ -1,10 +1,10 @@
{ {
"mean_rmsd_100": 2.519821240901947, "mean_rmsd_100": 2.5055564713478087,
"num_runs": 100, "num_runs": 100,
"timestamp_utc": "2026-04-16T08:21:04.567731+00:00", "timestamp_utc": "2026-04-16T08:29:07.275746+00:00",
"command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --epochs 120 --batch-size 96 --num-workers 8 --prefetch-factor 8 --hidden 512 --gcn-layers 8 --accum 2 --time-power 1.0 --eval-runs 100 --out-dir /tmp/ai_rfm_checkpoint_artifacts", "command": "train.py --sdf /data/demian_dev/toy/sample.sdf --model-type gcn --epochs 140 --batch-size 96 --num-workers 8 --prefetch-factor 8 --hidden 512 --gcn-layers 8 --accum 2 --time-power 1.0 --weight-center 0.8 --weight-omega 2.0 --weight-torsion 3.0 --grad-clip 0.8 --eval-runs 100 --out-dir /tmp/ai_rfm_weighted_a",
"notes": "Final test metric from 100 random-initialized rollouts to time=1.", "notes": "Final test metric from 100 random-initialized rollouts to time=1.",
"best_train_mse": 2.259780168533325, "best_train_mse": 5.101677894592285,
"model_source": "best_train_checkpoint", "model_source": "best_train_checkpoint",
"checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/best_model.pt" "checkpoint_path": "/data/demian_dev/toy/ai_rfm/artifacts/best_model.pt"
} }

View File

@@ -618,6 +618,10 @@ def main() -> int:
parser.add_argument("--epochs", type=int, default=2000) parser.add_argument("--epochs", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--weight-center", type=float, default=1.0, help="Loss weight for translation velocity.")
parser.add_argument("--weight-omega", type=float, default=2.0, help="Loss weight for rotation velocity.")
parser.add_argument("--weight-torsion", type=float, default=4.0, help="Loss weight for torsion velocity.")
parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping max norm.")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)") parser.add_argument("--hidden", type=int, default=128, help="GCN hidden dim (larger => more GPU work)")
parser.add_argument( parser.add_argument(
@@ -741,6 +745,11 @@ def main() -> int:
f"layers={args.gcn_layers} params={n_params:,}" f"layers={args.gcn_layers} params={n_params:,}"
) )
print(f"time sampling power={args.time_power}") print(f"time sampling power={args.time_power}")
print(
"loss weights: "
f"center={args.weight_center} omega={args.weight_omega} torsion={args.weight_torsion}; "
f"grad_clip={args.grad_clip}"
)
nw = args.num_workers nw = args.num_workers
if nw < 0: if nw < 0:
@@ -811,13 +820,15 @@ def main() -> int:
scale_w = (tw * tw).mean().clamp(min=eps) scale_w = (tw * tw).mean().clamp(min=eps)
scale_t = (tt * tt).mean().clamp(min=eps) scale_t = (tt * tt).mean().clamp(min=eps)
mse = ( mse = (
F.mse_loss(pred.center, tc) / scale_c args.weight_center * (F.mse_loss(pred.center, tc) / scale_c)
+ F.mse_loss(pred.rot_omega, tw) / scale_w + args.weight_omega * (F.mse_loss(pred.rot_omega, tw) / scale_w)
+ F.mse_loss(pred.torsion, tt) / scale_t + args.weight_torsion * (F.mse_loss(pred.torsion, tt) / scale_t)
) )
(mse / accum).backward() (mse / accum).backward()
sum_mse = sum_mse + mse.detach() sum_mse = sum_mse + mse.detach()
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step() optimizer.step()
avg_mse = sum_mse / accum avg_mse = sum_mse / accum