Update dependencies and refactor visualization to export simulation as GIF only. Added new packages for GIF and TIFF support, and modified the toy pipeline to utilize the new GIF export functionality.
This commit is contained in:
33
Cargo.lock
generated
33
Cargo.lock
generated
@@ -1607,6 +1607,16 @@ dependencies = [
|
||||
"wasip3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gif"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ae047235e33e2829703574b54fdec96bfbad892062d97fed2f76022287de61b"
|
||||
dependencies = [
|
||||
"color_quant",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gif"
|
||||
version = "0.14.2"
|
||||
@@ -1994,9 +2004,13 @@ dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"color_quant",
|
||||
"exr",
|
||||
"gif 0.13.3",
|
||||
"jpeg-decoder",
|
||||
"num-traits",
|
||||
"png 0.17.16",
|
||||
"qoi",
|
||||
"tiff 0.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2009,7 +2023,7 @@ dependencies = [
|
||||
"byteorder-lite",
|
||||
"color_quant",
|
||||
"exr",
|
||||
"gif",
|
||||
"gif 0.14.2",
|
||||
"image-webp",
|
||||
"moxcms",
|
||||
"num-traits",
|
||||
@@ -2018,7 +2032,7 @@ dependencies = [
|
||||
"ravif",
|
||||
"rayon",
|
||||
"rgb",
|
||||
"tiff",
|
||||
"tiff 0.11.3",
|
||||
"zune-core",
|
||||
"zune-jpeg",
|
||||
]
|
||||
@@ -2147,6 +2161,9 @@ name = "jpeg-decoder"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00810f1d8b74be64b13dbf3db89ac67740615d6c891f0e7b6179326533011a07"
|
||||
dependencies = [
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
@@ -3272,6 +3289,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"burn",
|
||||
"image 0.24.9",
|
||||
"plotters",
|
||||
"safetensors 0.4.5",
|
||||
"serde",
|
||||
@@ -3884,6 +3902,17 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e"
|
||||
dependencies = [
|
||||
"flate2",
|
||||
"jpeg-decoder",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.11.3"
|
||||
|
||||
@@ -5,6 +5,7 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
|
||||
image = "0.24"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1"
|
||||
|
||||
@@ -7,7 +7,7 @@ use riemann_flow_gnn::{
|
||||
conformer::LigandConformer,
|
||||
geometry::PoseState,
|
||||
graph::{Atom, Bond, LigandGraph, TorsionEdge},
|
||||
viz::{export_simulation_video, plot_static_png},
|
||||
viz::export_simulation_gif_only,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use riemann_flow_gnn::api::burn_training::{
|
||||
@@ -205,17 +205,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
)?
|
||||
};
|
||||
|
||||
// 3) simulation video export
|
||||
let trajectory = export_simulation_video(
|
||||
// 3) simulation video export (gif only)
|
||||
export_simulation_gif_only(
|
||||
&pipeline.batch.reference,
|
||||
&pipeline.batch.pose_start,
|
||||
&pipeline.batch.pose_target,
|
||||
48,
|
||||
"outputs",
|
||||
"outputs/simulation.gif",
|
||||
24,
|
||||
)?;
|
||||
plot_static_png(&trajectory[0], "outputs/static_start.png")?;
|
||||
plot_static_png(trajectory.last().ok_or("empty trajectory")?, "outputs/static_end.png")?;
|
||||
println!(
|
||||
"contract: nodes={}, edges={}, torsions={}, node_dim={}, edge_dim={}",
|
||||
graph_features.num_nodes,
|
||||
@@ -230,7 +228,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
history.last().map(|m| m.loss).unwrap_or(0.0)
|
||||
);
|
||||
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
|
||||
println!("saved video: outputs/simulation.mp4");
|
||||
println!("saved gif only: outputs/simulation.gif");
|
||||
println!("target reference: data/toy/aspirin_reference.json");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
94
src/viz.rs
94
src/viz.rs
@@ -1,11 +1,10 @@
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use plotters::prelude::*;
|
||||
use image::{codecs::gif::GifEncoder, Delay, Frame as ImageFrame};
|
||||
|
||||
use crate::{
|
||||
conformer::LigandConformer,
|
||||
@@ -109,69 +108,42 @@ pub fn plot_trajectory_png_frames<P: AsRef<Path>>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn export_xyz_frames<P: AsRef<Path>>(
|
||||
frames: &[LigandConformer],
|
||||
out_dir: P,
|
||||
) -> Result<(), std::io::Error> {
|
||||
fs::create_dir_all(out_dir.as_ref())?;
|
||||
for (idx, frame) in frames.iter().enumerate() {
|
||||
let out = out_dir.as_ref().join(format!("frame_{idx:04}.xyz"));
|
||||
let mut f = File::create(out)?;
|
||||
writeln!(f, "{}", frame.coords.len())?;
|
||||
writeln!(f, "riemann-flow-gnn frame {idx}")?;
|
||||
for (atom, c) in frame.graph.atoms.iter().zip(&frame.coords) {
|
||||
writeln!(f, "{} {:.6} {:.6} {:.6}", atom.atomic_number, c[0], c[1], c[2])?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn frames_to_mp4<P: AsRef<Path>>(
|
||||
frames_dir: P,
|
||||
output_mp4: P,
|
||||
fps: u32,
|
||||
) -> Result<bool, Box<dyn std::error::Error>> {
|
||||
let input_pattern: PathBuf = frames_dir.as_ref().join("frame_%04d.png");
|
||||
let status = match Command::new("ffmpeg")
|
||||
.args([
|
||||
"-y",
|
||||
"-framerate",
|
||||
&fps.to_string(),
|
||||
"-i",
|
||||
input_pattern.to_string_lossy().as_ref(),
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
output_mp4.as_ref().to_string_lossy().as_ref(),
|
||||
])
|
||||
.status()
|
||||
{
|
||||
Ok(status) => status,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
|
||||
Err(e) => return Err(Box::new(e)),
|
||||
};
|
||||
|
||||
if !status.success() {
|
||||
return Err("ffmpeg command failed".into());
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub fn export_simulation_video<P: AsRef<Path>>(
|
||||
pub fn export_simulation_gif_only<P: AsRef<Path>>(
|
||||
reference: &LigandConformer,
|
||||
pose_start: &PoseState,
|
||||
pose_target: &PoseState,
|
||||
num_steps: usize,
|
||||
out_dir: P,
|
||||
output_gif: P,
|
||||
fps: u32,
|
||||
) -> Result<Vec<LigandConformer>, Box<dyn std::error::Error>> {
|
||||
fs::create_dir_all(out_dir.as_ref())?;
|
||||
let frames_png_dir = out_dir.as_ref().join("frames_png");
|
||||
let frames_xyz_dir = out_dir.as_ref().join("frames_xyz");
|
||||
let video_path = out_dir.as_ref().join("simulation.mp4");
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let output_gif = output_gif.as_ref();
|
||||
if let Some(parent) = output_gif.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let temp_frames_dir = output_gif
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new("."))
|
||||
.join(".tmp_gif_frames");
|
||||
if temp_frames_dir.exists() {
|
||||
fs::remove_dir_all(&temp_frames_dir)?;
|
||||
}
|
||||
fs::create_dir_all(&temp_frames_dir)?;
|
||||
|
||||
let frames = simulate_trajectory(reference, pose_start, pose_target, num_steps)?;
|
||||
plot_trajectory_png_frames(&frames, &frames_png_dir)?;
|
||||
export_xyz_frames(&frames, &frames_xyz_dir)?;
|
||||
let _video_written = frames_to_mp4(&frames_png_dir, &video_path, fps)?;
|
||||
Ok(frames)
|
||||
plot_trajectory_png_frames(&frames, &temp_frames_dir)?;
|
||||
|
||||
let file = File::create(output_gif)?;
|
||||
let mut encoder = GifEncoder::new(file);
|
||||
let delay_ms = (1000 / fps.max(1)) as u32;
|
||||
for i in 0..frames.len() {
|
||||
let frame_path = temp_frames_dir.join(format!("frame_{i:04}.png"));
|
||||
let rgba = image::open(frame_path)?.to_rgba8();
|
||||
let frame =
|
||||
ImageFrame::from_parts(rgba, 0, 0, Delay::from_numer_denom_ms(delay_ms, 1));
|
||||
encoder.encode_frame(frame)?;
|
||||
}
|
||||
|
||||
fs::remove_dir_all(&temp_frames_dir)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user