Update project configuration and documentation for GPU support; add GPU environment check script and CUDA runtime validation

This commit is contained in:
demian3b
2026-04-15 23:48:51 +09:00
parent 9827bd2491
commit c637a63953
8 changed files with 63 additions and 29 deletions

2
.gitignore vendored
View File

@@ -1,3 +1,5 @@
outputs
# ---> Rust # ---> Rust
# Generated by Cargo # Generated by Cargo
# will have compiled files and executables # will have compiled files and executables

View File

@@ -3,17 +3,13 @@ name = "riemann-flow-gnn"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[features]
default = []
burn = ["dep:burn", "dep:safetensors"]
[dependencies] [dependencies]
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] } plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
thiserror = "1" thiserror = "1"
burn = { version = "0.14", optional = true, features = ["autodiff", "ndarray"] } burn = { version = "0.14", features = ["autodiff", "cuda-jit"] }
safetensors = { version = "0.4", optional = true } safetensors = { version = "0.4" }
[dev-dependencies] [dev-dependencies]
approx = "0.5" approx = "0.5"

View File

@@ -13,6 +13,20 @@ Riemannian Flow matching with GNN implementation in rust/burn
cargo run --example toy_pipeline cargo run --example toy_pipeline
``` ```
### NVIDIA GPU 환경 점검
```bash
bash scripts/gpu_env_check.sh
```
점검 항목:
- `nvidia-smi` 드라이버/디바이스 인식
- `nvcc --version` CUDA toolkit 인식
- `scripts/gpu_smoke_test.py` CUDA matmul 런타임 실행 (PyTorch 설치 필요)
- `cargo check --example toy_pipeline` Burn 빌드 스모크 테스트
예제 학습 스켈레톤은 **기본으로 Burn CUDA JIT 백엔드(GPU)** 를 사용합니다. NVIDIA 드라이버와 CUDA 개발 환경(`nvcc` 등)이 필요합니다.
예제는 파이프라인 계약 확인용으로 동작합니다. 예제는 파이프라인 계약 확인용으로 동작합니다.
### 현재 포함 범위 ### 현재 포함 범위
@@ -20,7 +34,7 @@ cargo run --example toy_pipeline
- graph/domain struct (`Atom`, `Bond`, `TorsionEdge`, `LigandGraph`) - graph/domain struct (`Atom`, `Bond`, `TorsionEdge`, `LigandGraph`)
- one-sample entry API (`api::OneSampleToyPipeline`) - one-sample entry API (`api::OneSampleToyPipeline`)
- graph/pose/simulator 최소 계약 - graph/pose/simulator 최소 계약
- Burn tensor shape 중심 adapter skeleton (`--features burn`) - Burn tensor shape 중심 adapter skeleton (기본 활성)
- TODO 기반 구현 포인트 주석 - TODO 기반 구현 포인트 주석
### 이후 단계 권장 ### 이후 단계 권장

View File

@@ -1,8 +1,6 @@
#[cfg(feature = "burn")]
use std::{collections::BTreeMap, marker::PhantomData, path::Path}; use std::{collections::BTreeMap, marker::PhantomData, path::Path};
use std::fs; use std::fs;
#[cfg(feature = "burn")]
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend}; use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
use riemann_flow_gnn::{ use riemann_flow_gnn::{
api::{OneSampleToyBatch, OneSampleToyPipeline}, api::{OneSampleToyBatch, OneSampleToyPipeline},
@@ -12,18 +10,15 @@ use riemann_flow_gnn::{
viz::{export_simulation_video, plot_static_png}, viz::{export_simulation_video, plot_static_png},
}; };
use serde::Deserialize; use serde::Deserialize;
#[cfg(feature = "burn")]
use riemann_flow_gnn::api::burn_training::{ use riemann_flow_gnn::api::burn_training::{
run_burn_training, BurnToyModel, BurnTrainConfig, run_burn_training, BurnToyModel, BurnTrainConfig,
}; };
#[cfg(feature = "burn")]
use safetensors::{ use safetensors::{
serialize_to_file, serialize_to_file,
tensor::{Dtype, TensorView}, tensor::{Dtype, TensorView},
}; };
#[cfg(feature = "burn")] type ExampleBackend = Autodiff<burn::backend::CudaJit>;
type ExampleBackend = Autodiff<burn::backend::NdArray>;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ToyAtom { struct ToyAtom {
@@ -100,12 +95,10 @@ fn load_toy_reference(path: &str) -> Result<(LigandGraph, LigandConformer), Box<
Ok((graph, reference)) Ok((graph, reference))
} }
#[cfg(feature = "burn")]
struct ExampleBurnModel<B: AutodiffBackend> { struct ExampleBurnModel<B: AutodiffBackend> {
_backend: PhantomData<B>, _backend: PhantomData<B>,
} }
#[cfg(feature = "burn")]
impl<B: AutodiffBackend> ExampleBurnModel<B> { impl<B: AutodiffBackend> ExampleBurnModel<B> {
fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> { fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> {
// TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현 // TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현
@@ -123,7 +116,6 @@ impl<B: AutodiffBackend> ExampleBurnModel<B> {
} }
} }
#[cfg(feature = "burn")]
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> { impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
fn name(&self) -> &'static str { fn name(&self) -> &'static str {
"burn-toy-model" "burn-toy-model"
@@ -196,7 +188,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}); });
let graph_features = pipeline.graph_contract(); let graph_features = pipeline.graph_contract();
#[cfg(feature = "burn")]
let history = { let history = {
let device = Default::default(); let device = Default::default();
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device); let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
@@ -233,16 +224,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
graph_features.spec.node_feat_dim, graph_features.spec.node_feat_dim,
graph_features.spec.edge_feat_dim graph_features.spec.edge_feat_dim
); );
#[cfg(feature = "burn")]
println!( println!(
"burn training finished: epochs={}, last_loss={:.6}", "burn training finished: epochs={}, last_loss={:.6}",
history.len(), history.len(),
history.last().map(|m| m.loss).unwrap_or(0.0) history.last().map(|m| m.loss).unwrap_or(0.0)
); );
#[cfg(feature = "burn")]
println!("saved checkpoints: outputs/checkpoints/*.safetensors"); println!("saved checkpoints: outputs/checkpoints/*.safetensors");
#[cfg(not(feature = "burn"))]
println!("burn feature is off: training skeleton skipped");
println!("saved video: outputs/simulation.mp4"); println!("saved video: outputs/simulation.mp4");
println!("target reference: data/toy/aspirin_reference.json"); println!("target reference: data/toy/aspirin_reference.json");
Ok(()) Ok(())

20
scripts/gpu_env_check.sh Executable file
View File

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
# GPU environment check for this repository.
#
# Runs four checks in order and aborts on the first failure (set -e):
#   1. nvidia-smi                           - driver / device visibility
#   2. nvcc --version                       - CUDA compiler on PATH
#   3. scripts/gpu_smoke_test.py            - CUDA matmul through PyTorch
#   4. cargo check --example toy_pipeline   - the Burn-based example builds
#
# Environment:
#   PYTHON  Interpreter used for step 3. When unset, `python` is used if it is
#           on PATH, otherwise `python3`.
set -euo pipefail

# The script lives in <repo>/scripts; switch to the repository root so the
# relative script path and the cargo invocation work from any directory.
script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
cd -- "${script_dir}/.."

# Exit with a readable message (and the conventional 127 status) when a
# required tool is not installed, instead of a bare "command not found".
require_cmd() {
  if ! command -v "$1" >/dev/null 2>&1; then
    echo "error: required command '$1' was not found on PATH" >&2
    exit 127
  fi
}

# Pick the Python interpreter: explicit override first, then `python`,
# then `python3` for systems that do not provide an unversioned binary.
python_bin="${PYTHON:-}"
if [[ -z "${python_bin}" ]]; then
  if command -v python >/dev/null 2>&1; then
    python_bin="python"
  else
    python_bin="python3"
  fi
fi

echo "[1/4] NVIDIA driver check"
require_cmd nvidia-smi
nvidia-smi
echo

echo "[2/4] CUDA compiler check"
require_cmd nvcc
nvcc --version
echo

echo "[3/4] PyTorch CUDA runtime check"
require_cmd "${python_bin}"
"${python_bin}" "scripts/gpu_smoke_test.py"
echo

# Burn is an unconditional dependency now, so this is a plain build check
# rather than a feature-gated one.
echo "[4/4] Burn build check (toy_pipeline example)"
require_cmd cargo
cargo check --example toy_pipeline
echo

echo "GPU environment check finished successfully."

23
scripts/gpu_smoke_test.py Normal file
View File

@@ -0,0 +1,23 @@
import torch
def main() -> None:
    """Print the PyTorch/CUDA setup and run one matrix multiply on the GPU.

    Output, one ``key=value`` line each: the torch version, whether CUDA is
    available, the CUDA device count, the name of every device, and finally
    the shape/dtype/device of the matmul result.

    Raises:
        SystemExit: with the message "CUDA is not available" when
            ``torch.cuda.is_available()`` is false.
    """
    print(f"torch={torch.__version__}")
    has_cuda = torch.cuda.is_available()
    print(f"cuda_available={has_cuda}")
    if not has_cuda:
        raise SystemExit("CUDA is not available")

    gpu_total = torch.cuda.device_count()
    print(f"cuda_device_count={gpu_total}")
    for index in range(gpu_total):
        gpu_name = torch.cuda.get_device_name(index)
        print(f"gpu[{index}]={gpu_name}")

    # Two random square operands created directly on the default CUDA device.
    shape = (2048, 2048)
    lhs = torch.randn(shape, device="cuda")
    rhs = torch.randn(shape, device="cuda")
    product = torch.matmul(lhs, rhs)
    # Wait for the queued GPU work to finish before reporting success.
    torch.cuda.synchronize()
    print(
        f"matmul_ok shape={tuple(product.shape)} dtype={product.dtype} device={product.device}"
    )


if __name__ == "__main__":
    main()

View File

@@ -46,7 +46,6 @@ impl OneSampleToyPipeline {
} }
} }
#[cfg(feature = "burn")]
pub mod burn_training { pub mod burn_training {
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,

View File

@@ -1,12 +1,8 @@
#[cfg(feature = "burn")]
use burn::tensor::{backend::Backend, Int, Tensor}; use burn::tensor::{backend::Backend, Int, Tensor};
#[cfg(feature = "burn")]
use crate::geometry::PoseState; use crate::geometry::PoseState;
#[cfg(feature = "burn")]
use crate::graph::GraphFeatures; use crate::graph::GraphFeatures;
#[cfg(feature = "burn")]
pub struct GraphTensors<B: Backend> { pub struct GraphTensors<B: Backend> {
pub node_features: Tensor<B, 2>, pub node_features: Tensor<B, 2>,
pub edge_index: Tensor<B, 2, Int>, pub edge_index: Tensor<B, 2, Int>,
@@ -14,7 +10,6 @@ pub struct GraphTensors<B: Backend> {
pub torsion_edge_index: Tensor<B, 2, Int>, pub torsion_edge_index: Tensor<B, 2, Int>,
} }
#[cfg(feature = "burn")]
impl<B: Backend> GraphTensors<B> { impl<B: Backend> GraphTensors<B> {
pub fn from_features(features: &GraphFeatures, device: &B::Device) -> Self { pub fn from_features(features: &GraphFeatures, device: &B::Device) -> Self {
Self { Self {
@@ -35,14 +30,12 @@ impl<B: Backend> GraphTensors<B> {
} }
} }
#[cfg(feature = "burn")]
pub struct PoseTensors<B: Backend> { pub struct PoseTensors<B: Backend> {
pub translation: Tensor<B, 2>, pub translation: Tensor<B, 2>,
pub rotation: Tensor<B, 2>, pub rotation: Tensor<B, 2>,
pub torsions: Tensor<B, 2>, pub torsions: Tensor<B, 2>,
} }
#[cfg(feature = "burn")]
impl<B: Backend> PoseTensors<B> { impl<B: Backend> PoseTensors<B> {
pub fn from_pose_batch(poses: &[PoseState], device: &B::Device) -> Self { pub fn from_pose_batch(poses: &[PoseState], device: &B::Device) -> Self {
let batch = poses.len(); let batch = poses.len();