Update dependencies and refactor visualization to export simulation as GIF only. Added new packages for GIF and TIFF support, and modified the toy pipeline to utilize the new GIF export functionality.

This commit is contained in:
demian3b
2026-04-16 00:12:12 +09:00
parent c637a63953
commit db105a77e7
4 changed files with 70 additions and 70 deletions

33
Cargo.lock generated
View File

@@ -1607,6 +1607,16 @@ dependencies = [
"wasip3",
]
[[package]]
name = "gif"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ae047235e33e2829703574b54fdec96bfbad892062d97fed2f76022287de61b"
dependencies = [
"color_quant",
"weezl",
]
[[package]]
name = "gif"
version = "0.14.2"
@@ -1994,9 +2004,13 @@ dependencies = [
"bytemuck",
"byteorder",
"color_quant",
"exr",
"gif 0.13.3",
"jpeg-decoder",
"num-traits",
"png 0.17.16",
"qoi",
"tiff 0.9.1",
]
[[package]]
@@ -2009,7 +2023,7 @@ dependencies = [
"byteorder-lite",
"color_quant",
"exr",
"gif",
"gif 0.14.2",
"image-webp",
"moxcms",
"num-traits",
@@ -2018,7 +2032,7 @@ dependencies = [
"ravif",
"rayon",
"rgb",
"tiff",
"tiff 0.11.3",
"zune-core",
"zune-jpeg",
]
@@ -2147,6 +2161,9 @@ name = "jpeg-decoder"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00810f1d8b74be64b13dbf3db89ac67740615d6c891f0e7b6179326533011a07"
dependencies = [
"rayon",
]
[[package]]
name = "js-sys"
@@ -3272,6 +3289,7 @@ version = "0.1.0"
dependencies = [
"approx",
"burn",
"image 0.24.9",
"plotters",
"safetensors 0.4.5",
"serde",
@@ -3884,6 +3902,17 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "tiff"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e"
dependencies = [
"flate2",
"jpeg-decoder",
"weezl",
]
[[package]]
name = "tiff"
version = "0.11.3"

View File

@@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
image = "0.24"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"

View File

@@ -7,7 +7,7 @@ use riemann_flow_gnn::{
conformer::LigandConformer,
geometry::PoseState,
graph::{Atom, Bond, LigandGraph, TorsionEdge},
viz::{export_simulation_video, plot_static_png},
viz::export_simulation_gif_only,
};
use serde::Deserialize;
use riemann_flow_gnn::api::burn_training::{
@@ -205,17 +205,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
)?
};
// 3) simulation video export
let trajectory = export_simulation_video(
// 3) simulation video export (gif only)
export_simulation_gif_only(
&pipeline.batch.reference,
&pipeline.batch.pose_start,
&pipeline.batch.pose_target,
48,
"outputs",
"outputs/simulation.gif",
24,
)?;
plot_static_png(&trajectory[0], "outputs/static_start.png")?;
plot_static_png(trajectory.last().ok_or("empty trajectory")?, "outputs/static_end.png")?;
println!(
"contract: nodes={}, edges={}, torsions={}, node_dim={}, edge_dim={}",
graph_features.num_nodes,
@@ -230,7 +228,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
history.last().map(|m| m.loss).unwrap_or(0.0)
);
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
println!("saved video: outputs/simulation.mp4");
println!("saved gif only: outputs/simulation.gif");
println!("target reference: data/toy/aspirin_reference.json");
Ok(())
}

View File

@@ -1,11 +1,10 @@
use std::{
fs::{self, File},
io::Write,
path::{Path, PathBuf},
process::Command,
path::Path,
};
use plotters::prelude::*;
use image::{codecs::gif::GifEncoder, Delay, Frame as ImageFrame};
use crate::{
conformer::LigandConformer,
@@ -109,69 +108,42 @@ pub fn plot_trajectory_png_frames<P: AsRef<Path>>(
Ok(())
}
pub fn export_xyz_frames<P: AsRef<Path>>(
frames: &[LigandConformer],
out_dir: P,
) -> Result<(), std::io::Error> {
fs::create_dir_all(out_dir.as_ref())?;
for (idx, frame) in frames.iter().enumerate() {
let out = out_dir.as_ref().join(format!("frame_{idx:04}.xyz"));
let mut f = File::create(out)?;
writeln!(f, "{}", frame.coords.len())?;
writeln!(f, "riemann-flow-gnn frame {idx}")?;
for (atom, c) in frame.graph.atoms.iter().zip(&frame.coords) {
writeln!(f, "{} {:.6} {:.6} {:.6}", atom.atomic_number, c[0], c[1], c[2])?;
}
}
Ok(())
}
pub fn frames_to_mp4<P: AsRef<Path>>(
frames_dir: P,
output_mp4: P,
fps: u32,
) -> Result<bool, Box<dyn std::error::Error>> {
let input_pattern: PathBuf = frames_dir.as_ref().join("frame_%04d.png");
let status = match Command::new("ffmpeg")
.args([
"-y",
"-framerate",
&fps.to_string(),
"-i",
input_pattern.to_string_lossy().as_ref(),
"-pix_fmt",
"yuv420p",
output_mp4.as_ref().to_string_lossy().as_ref(),
])
.status()
{
Ok(status) => status,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
Err(e) => return Err(Box::new(e)),
};
if !status.success() {
return Err("ffmpeg command failed".into());
}
Ok(true)
}
pub fn export_simulation_video<P: AsRef<Path>>(
pub fn export_simulation_gif_only<P: AsRef<Path>>(
reference: &LigandConformer,
pose_start: &PoseState,
pose_target: &PoseState,
num_steps: usize,
out_dir: P,
output_gif: P,
fps: u32,
) -> Result<Vec<LigandConformer>, Box<dyn std::error::Error>> {
fs::create_dir_all(out_dir.as_ref())?;
let frames_png_dir = out_dir.as_ref().join("frames_png");
let frames_xyz_dir = out_dir.as_ref().join("frames_xyz");
let video_path = out_dir.as_ref().join("simulation.mp4");
) -> Result<(), Box<dyn std::error::Error>> {
let output_gif = output_gif.as_ref();
if let Some(parent) = output_gif.parent() {
fs::create_dir_all(parent)?;
}
let temp_frames_dir = output_gif
.parent()
.unwrap_or_else(|| Path::new("."))
.join(".tmp_gif_frames");
if temp_frames_dir.exists() {
fs::remove_dir_all(&temp_frames_dir)?;
}
fs::create_dir_all(&temp_frames_dir)?;
let frames = simulate_trajectory(reference, pose_start, pose_target, num_steps)?;
plot_trajectory_png_frames(&frames, &frames_png_dir)?;
export_xyz_frames(&frames, &frames_xyz_dir)?;
let _video_written = frames_to_mp4(&frames_png_dir, &video_path, fps)?;
Ok(frames)
plot_trajectory_png_frames(&frames, &temp_frames_dir)?;
let file = File::create(output_gif)?;
let mut encoder = GifEncoder::new(file);
let delay_ms = (1000 / fps.max(1)) as u32;
for i in 0..frames.len() {
let frame_path = temp_frames_dir.join(format!("frame_{i:04}.png"));
let rgba = image::open(frame_path)?.to_rgba8();
let frame =
ImageFrame::from_parts(rgba, 0, 0, Delay::from_numer_denom_ms(delay_ms, 1));
encoder.encode_frame(frame)?;
}
fs::remove_dir_all(&temp_frames_dir)?;
Ok(())
}