237 lines
6.9 KiB
Rust
237 lines
6.9 KiB
Rust
use std::{collections::BTreeMap, marker::PhantomData, path::Path};
|
|
use std::fs;
|
|
|
|
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
|
|
use riemann_flow_gnn::{
|
|
api::{OneSampleToyBatch, OneSampleToyPipeline},
|
|
conformer::LigandConformer,
|
|
geometry::PoseState,
|
|
graph::{Atom, Bond, LigandGraph, TorsionEdge},
|
|
viz::{export_simulation_video, plot_static_png},
|
|
};
|
|
use serde::Deserialize;
|
|
use riemann_flow_gnn::api::burn_training::{
|
|
run_burn_training, BurnToyModel, BurnTrainConfig,
|
|
};
|
|
use safetensors::{
|
|
serialize_to_file,
|
|
tensor::{Dtype, TensorView},
|
|
};
|
|
|
|
type ExampleBackend = Autodiff<burn::backend::CudaJit>;
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ToyAtom {
|
|
index: usize,
|
|
atomic_number: u8,
|
|
formal_charge: i8,
|
|
is_aromatic: bool,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ToyBond {
|
|
index: usize,
|
|
src: usize,
|
|
dst: usize,
|
|
bond_order: u8,
|
|
is_aromatic: bool,
|
|
is_rotatable: bool,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ToyReferenceData {
|
|
atoms: Vec<ToyAtom>,
|
|
bonds: Vec<ToyBond>,
|
|
coords_angstrom: Vec<[f32; 3]>,
|
|
torsion_edges: Vec<TorsionEdgeJson>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct TorsionEdgeJson {
|
|
bond_index: usize,
|
|
atom_left: usize,
|
|
atom_right: usize,
|
|
rotating_side: Vec<usize>,
|
|
}
|
|
|
|
fn load_toy_reference(path: &str) -> Result<(LigandGraph, LigandConformer), Box<dyn std::error::Error>> {
|
|
let raw = fs::read_to_string(path)?;
|
|
let parsed: ToyReferenceData = serde_json::from_str(&raw)?;
|
|
|
|
let atoms = parsed
|
|
.atoms
|
|
.into_iter()
|
|
.map(|a| Atom {
|
|
index: a.index,
|
|
atomic_number: a.atomic_number,
|
|
formal_charge: a.formal_charge,
|
|
is_aromatic: a.is_aromatic,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let bonds = parsed
|
|
.bonds
|
|
.into_iter()
|
|
.map(|b| Bond {
|
|
index: b.index,
|
|
src: b.src,
|
|
dst: b.dst,
|
|
bond_order: b.bond_order,
|
|
is_aromatic: b.is_aromatic,
|
|
is_rotatable: b.is_rotatable,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let torsion_edges = parsed
|
|
.torsion_edges
|
|
.into_iter()
|
|
.map(|t| TorsionEdge {
|
|
bond_index: t.bond_index,
|
|
atom_left: t.atom_left,
|
|
atom_right: t.atom_right,
|
|
rotating_side: t.rotating_side,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let graph = LigandGraph::new(atoms, bonds, torsion_edges);
|
|
let reference = LigandConformer::new(graph.clone(), parsed.coords_angstrom)?;
|
|
Ok((graph, reference))
|
|
}
|
|
|
|
/// Placeholder Burn model used by this example.
///
/// It holds no parameters yet; `PhantomData<B>` only ties the type to a backend `B`.
/// The `AutodiffBackend` requirement is stated on the `impl` blocks that need it rather than
/// on the struct definition (Rust API guideline C-STRUCT-BOUNDS), so the type itself stays
/// unconstrained.
struct ExampleBurnModel<B> {
    _backend: PhantomData<B>,
}
|
|
|
|
impl<B: AutodiffBackend> ExampleBurnModel<B> {
|
|
fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> {
|
|
// TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현
|
|
Ok(())
|
|
}
|
|
|
|
fn loss_template(&self) -> Result<f64, String> {
|
|
// TODO(user): translation/rotation/torsion loss 조합 구현
|
|
Ok(1.0)
|
|
}
|
|
|
|
fn optimizer_step_template(&mut self, _lr: f64) -> Result<(), String> {
|
|
// TODO(user): Burn optimizer step 구현
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
|
|
fn name(&self) -> &'static str {
|
|
"burn-toy-model"
|
|
}
|
|
|
|
fn new(_device: &B::Device) -> Self {
|
|
Self {
|
|
_backend: PhantomData,
|
|
}
|
|
}
|
|
|
|
fn initialize(
|
|
&mut self,
|
|
_graph: &riemann_flow_gnn::graph::GraphFeatures,
|
|
_device: &B::Device,
|
|
) -> Result<(), String> {
|
|
// TODO(user): Burn 모델 파라미터 초기화
|
|
Ok(())
|
|
}
|
|
|
|
fn train_step(
|
|
&mut self,
|
|
graph: &riemann_flow_gnn::graph::GraphFeatures,
|
|
epoch: usize,
|
|
lr: f64,
|
|
) -> Result<f64, String> {
|
|
self.forward_template(graph)?;
|
|
let base_loss = self.loss_template()?;
|
|
self.optimizer_step_template(lr)?;
|
|
Ok(base_loss / (epoch as f64))
|
|
}
|
|
|
|
fn save_model_safetensors(&self, path: &Path) -> Result<(), String> {
|
|
// TODO(user): 실제 model state tensor들을 safetensors로 저장
|
|
let value_bytes = 1.0f32.to_le_bytes();
|
|
let value_view =
|
|
TensorView::new(Dtype::F32, vec![1], &value_bytes).map_err(|e| e.to_string())?;
|
|
let tensors = BTreeMap::from([("placeholder_weight".to_string(), value_view)]);
|
|
serialize_to_file(tensors, &None, path).map_err(|e| e.to_string())?;
|
|
Ok(())
|
|
}
|
|
|
|
fn load_model_safetensors(_device: &B::Device, _path: &Path) -> Result<Self, String> {
|
|
// TODO(user): safetensors에서 model state 복원
|
|
Ok(Self {
|
|
_backend: PhantomData,
|
|
})
|
|
}
|
|
}
|
|
|
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
let (graph, reference) = load_toy_reference("data/toy/aspirin_reference.json")?;
|
|
|
|
let pose_start = PoseState {
|
|
translation: [0.0, 0.0, 0.0],
|
|
rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0],
|
|
torsions: vec![0.0],
|
|
};
|
|
let pose_target = PoseState {
|
|
translation: [1.0, 1.5, 0.5],
|
|
rotation_quat_xyzw: [0.0, 0.0, 0.38268343, 0.9238795],
|
|
torsions: vec![1.2],
|
|
};
|
|
|
|
let pipeline = OneSampleToyPipeline::new(OneSampleToyBatch {
|
|
graph,
|
|
reference: reference.clone(),
|
|
pose_start,
|
|
pose_target,
|
|
});
|
|
let graph_features = pipeline.graph_contract();
|
|
|
|
let history = {
|
|
let device = Default::default();
|
|
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
|
|
let train_config = BurnTrainConfig {
|
|
epochs: 20,
|
|
learning_rate: 1e-3,
|
|
checkpoint_every: 5,
|
|
checkpoint_dir: "outputs/checkpoints".to_string(),
|
|
};
|
|
run_burn_training::<ExampleBackend, _>(
|
|
&mut model,
|
|
&graph_features,
|
|
&device,
|
|
&train_config,
|
|
)?
|
|
};
|
|
|
|
// 3) simulation video export
|
|
let trajectory = export_simulation_video(
|
|
&pipeline.batch.reference,
|
|
&pipeline.batch.pose_start,
|
|
&pipeline.batch.pose_target,
|
|
48,
|
|
"outputs",
|
|
24,
|
|
)?;
|
|
plot_static_png(&trajectory[0], "outputs/static_start.png")?;
|
|
plot_static_png(trajectory.last().ok_or("empty trajectory")?, "outputs/static_end.png")?;
|
|
println!(
|
|
"contract: nodes={}, edges={}, torsions={}, node_dim={}, edge_dim={}",
|
|
graph_features.num_nodes,
|
|
graph_features.num_edges,
|
|
graph_features.num_torsion_edges,
|
|
graph_features.spec.node_feat_dim,
|
|
graph_features.spec.edge_feat_dim
|
|
);
|
|
println!(
|
|
"burn training finished: epochs={}, last_loss={:.6}",
|
|
history.len(),
|
|
history.last().map(|m| m.loss).unwrap_or(0.0)
|
|
);
|
|
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
|
|
println!("saved video: outputs/simulation.mp4");
|
|
println!("target reference: data/toy/aspirin_reference.json");
|
|
Ok(())
|
|
}
|