// riemann-flow-gnn/examples/toy_pipeline.rs
// (source-viewer listing metadata: 235 lines, 6.7 KiB, Rust)

use std::{collections::BTreeMap, marker::PhantomData, path::Path};
use std::fs;
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
use riemann_flow_gnn::{
api::{OneSampleToyBatch, OneSampleToyPipeline},
conformer::LigandConformer,
geometry::PoseState,
graph::{Atom, Bond, LigandGraph, TorsionEdge},
viz::export_simulation_gif_only,
};
use serde::Deserialize;
use riemann_flow_gnn::api::burn_training::{
run_burn_training, BurnToyModel, BurnTrainConfig,
};
use safetensors::{
serialize_to_file,
tensor::{Dtype, TensorView},
};
/// Burn backend used throughout this example: burn's `CudaJit` backend wrapped in the
/// `Autodiff` decorator so that training can record gradients.
///
/// NOTE(review): this presumably requires a CUDA-capable GPU and burn's CUDA feature to be
/// enabled — confirm against the crate's Cargo features before running on other machines.
type ExampleBackend = Autodiff<burn::backend::CudaJit>;
/// One record of the `atoms` array in the toy reference JSON.
///
/// Converted field-for-field into the crate's `Atom` by `load_toy_reference`.
#[derive(Deserialize, Debug)]
struct ToyAtom {
    /// Atom index, copied verbatim into `Atom::index`.
    index: usize,
    /// Element given as its atomic number.
    atomic_number: u8,
    /// Formal charge of the atom.
    formal_charge: i8,
    /// Whether the atom is flagged as aromatic.
    is_aromatic: bool,
}
/// One record of the `bonds` array in the toy reference JSON.
///
/// Converted field-for-field into the crate's `Bond` by `load_toy_reference`.
#[derive(Deserialize, Debug)]
struct ToyBond {
    /// Bond index, copied verbatim into `Bond::index`.
    index: usize,
    /// First endpoint (presumably an atom index from the `atoms` array — confirm).
    src: usize,
    /// Second endpoint (presumably an atom index from the `atoms` array — confirm).
    dst: usize,
    /// Bond order as stored in the JSON.
    bond_order: u8,
    /// Whether the bond is flagged as aromatic.
    is_aromatic: bool,
    /// Whether the bond is flagged as rotatable.
    is_rotatable: bool,
}
/// Top-level shape of the toy reference JSON file read by `load_toy_reference`.
#[derive(Deserialize, Debug)]
struct ToyReferenceData {
    /// Atom records (see [`ToyAtom`]).
    atoms: Vec<ToyAtom>,
    /// Bond records (see [`ToyBond`]).
    bonds: Vec<ToyBond>,
    /// Reference coordinates as `[x, y, z]` triples in ångström; presumably one entry per
    /// atom, in the same order as `atoms` — `LigandConformer::new` is what checks this.
    coords_angstrom: Vec<[f32; 3]>,
    /// Torsion edge records (see [`TorsionEdgeJson`]).
    torsion_edges: Vec<TorsionEdgeJson>,
}
/// One record of the `torsion_edges` array in the toy reference JSON.
///
/// Converted field-for-field into the crate's `TorsionEdge` by `load_toy_reference`.
#[derive(Deserialize, Debug)]
struct TorsionEdgeJson {
    /// Index of the bond this torsion rotates about (presumably into `bonds` — confirm).
    bond_index: usize,
    /// Atom on one side of the torsion bond.
    atom_left: usize,
    /// Atom on the other side of the torsion bond.
    atom_right: usize,
    /// Atom indices on the side that moves when the torsion angle changes (presumably).
    rotating_side: Vec<usize>,
}
fn load_toy_reference(path: &str) -> Result<(LigandGraph, LigandConformer), Box<dyn std::error::Error>> {
let raw = fs::read_to_string(path)?;
let parsed: ToyReferenceData = serde_json::from_str(&raw)?;
let atoms = parsed
.atoms
.into_iter()
.map(|a| Atom {
index: a.index,
atomic_number: a.atomic_number,
formal_charge: a.formal_charge,
is_aromatic: a.is_aromatic,
})
.collect::<Vec<_>>();
let bonds = parsed
.bonds
.into_iter()
.map(|b| Bond {
index: b.index,
src: b.src,
dst: b.dst,
bond_order: b.bond_order,
is_aromatic: b.is_aromatic,
is_rotatable: b.is_rotatable,
})
.collect::<Vec<_>>();
let torsion_edges = parsed
.torsion_edges
.into_iter()
.map(|t| TorsionEdge {
bond_index: t.bond_index,
atom_left: t.atom_left,
atom_right: t.atom_right,
rotating_side: t.rotating_side,
})
.collect::<Vec<_>>();
let graph = LigandGraph::new(atoms, bonds, torsion_edges);
let reference = LigandConformer::new(graph.clone(), parsed.coords_angstrom)?;
Ok((graph, reference))
}
/// Placeholder Burn model used by this example.
///
/// It holds no parameters yet; the `PhantomData<B>` field only ties the type to a backend
/// `B`. The `AutodiffBackend` requirement is expressed on the `impl` blocks that need it
/// rather than on the struct definition, per Rust API guidelines.
struct ExampleBurnModel<B> {
    _backend: PhantomData<B>,
}
impl<B: AutodiffBackend> ExampleBurnModel<B> {
    /// Placeholder forward pass: performs no computation and always succeeds.
    ///
    /// TODO(user): implement the forward pass on Burn tensor inputs
    /// (GraphTensors, PoseTensors).
    fn forward_template(
        &self,
        _graph: &riemann_flow_gnn::graph::GraphFeatures,
    ) -> Result<(), String> {
        Ok(())
    }

    /// Placeholder loss: always returns `Ok(1.0)`.
    ///
    /// TODO(user): implement the combined translation/rotation/torsion loss.
    fn loss_template(&self) -> Result<f64, String> {
        const PLACEHOLDER_LOSS: f64 = 1.0;
        Ok(PLACEHOLDER_LOSS)
    }

    /// Placeholder optimizer step: ignores the learning rate and always succeeds.
    ///
    /// TODO(user): implement the Burn optimizer step.
    fn optimizer_step_template(&mut self, _lr: f64) -> Result<(), String> {
        Ok(())
    }
}
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
fn name(&self) -> &'static str {
"burn-toy-model"
}
fn new(_device: &B::Device) -> Self {
Self {
_backend: PhantomData,
}
}
fn initialize(
&mut self,
_graph: &riemann_flow_gnn::graph::GraphFeatures,
_device: &B::Device,
) -> Result<(), String> {
// TODO(user): Burn 모델 파라미터 초기화
Ok(())
}
fn train_step(
&mut self,
graph: &riemann_flow_gnn::graph::GraphFeatures,
epoch: usize,
lr: f64,
) -> Result<f64, String> {
self.forward_template(graph)?;
let base_loss = self.loss_template()?;
self.optimizer_step_template(lr)?;
Ok(base_loss / (epoch as f64))
}
fn save_model_safetensors(&self, path: &Path) -> Result<(), String> {
// TODO(user): 실제 model state tensor들을 safetensors로 저장
let value_bytes = 1.0f32.to_le_bytes();
let value_view =
TensorView::new(Dtype::F32, vec![1], &value_bytes).map_err(|e| e.to_string())?;
let tensors = BTreeMap::from([("placeholder_weight".to_string(), value_view)]);
serialize_to_file(tensors, &None, path).map_err(|e| e.to_string())?;
Ok(())
}
fn load_model_safetensors(_device: &B::Device, _path: &Path) -> Result<Self, String> {
// TODO(user): safetensors에서 model state 복원
Ok(Self {
_backend: PhantomData,
})
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (graph, reference) = load_toy_reference("data/toy/aspirin_reference.json")?;
let pose_start = PoseState {
translation: [0.0, 0.0, 0.0],
rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0],
torsions: vec![0.0],
};
let pose_target = PoseState {
translation: [1.0, 1.5, 0.5],
rotation_quat_xyzw: [0.0, 0.0, 0.38268343, 0.9238795],
torsions: vec![1.2],
};
let pipeline = OneSampleToyPipeline::new(OneSampleToyBatch {
graph,
reference: reference.clone(),
pose_start,
pose_target,
});
let graph_features = pipeline.graph_contract();
let history = {
let device = Default::default();
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
let train_config = BurnTrainConfig {
epochs: 20,
learning_rate: 1e-3,
checkpoint_every: 5,
checkpoint_dir: "outputs/checkpoints".to_string(),
};
run_burn_training::<ExampleBackend, _>(
&mut model,
&graph_features,
&device,
&train_config,
)?
};
// 3) simulation video export (gif only)
export_simulation_gif_only(
&pipeline.batch.reference,
&pipeline.batch.pose_start,
&pipeline.batch.pose_target,
48,
"outputs/simulation.gif",
24,
)?;
println!(
"contract: nodes={}, edges={}, torsions={}, node_dim={}, edge_dim={}",
graph_features.num_nodes,
graph_features.num_edges,
graph_features.num_torsion_edges,
graph_features.spec.node_feat_dim,
graph_features.spec.edge_feat_dim
);
println!(
"burn training finished: epochs={}, last_loss={:.6}",
history.len(),
history.last().map(|m| m.loss).unwrap_or(0.0)
);
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
println!("saved gif only: outputs/simulation.gif");
println!("target reference: data/toy/aspirin_reference.json");
Ok(())
}