Update project configuration and documentation for GPU support; add GPU environment check script and CUDA runtime validation
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,5 @@
|
|||||||
|
outputs
|
||||||
|
|
||||||
# ---> Rust
|
# ---> Rust
|
||||||
# Generated by Cargo
|
# Generated by Cargo
|
||||||
# will have compiled files and executables
|
# will have compiled files and executables
|
||||||
|
|||||||
@@ -3,17 +3,13 @@ name = "riemann-flow-gnn"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[features]
|
|
||||||
default = []
|
|
||||||
burn = ["dep:burn", "dep:safetensors"]
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
|
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
burn = { version = "0.14", optional = true, features = ["autodiff", "ndarray"] }
|
burn = { version = "0.14", features = ["autodiff", "cuda-jit"] }
|
||||||
safetensors = { version = "0.4", optional = true }
|
safetensors = { version = "0.4" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
approx = "0.5"
|
approx = "0.5"
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -13,6 +13,20 @@ Riemannian Flow matching with GNN implementation in rust/burn
|
|||||||
cargo run --example toy_pipeline
|
cargo run --example toy_pipeline
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### NVIDIA GPU 환경 점검
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/gpu_env_check.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
점검 항목:
|
||||||
|
- `nvidia-smi` 드라이버/디바이스 인식
|
||||||
|
- `nvcc --version` CUDA toolkit 인식
|
||||||
|
- `scripts/gpu_smoke_test.py` PyTorch 기반 CUDA matmul 런타임 실행 (`torch` 설치 필요)
|
||||||
|
- `cargo check --example toy_pipeline` Burn 빌드 스모크 테스트
|
||||||
|
|
||||||
|
예제 학습 스켈레톤은 **기본으로 Burn CUDA JIT 백엔드(GPU)** 를 사용합니다. NVIDIA 드라이버와 CUDA 개발 환경(`nvcc` 등)이 필요합니다.
|
||||||
|
|
||||||
예제는 파이프라인 계약 확인용으로 동작합니다.
|
예제는 파이프라인 계약 확인용으로 동작합니다.
|
||||||
|
|
||||||
### 현재 포함 범위
|
### 현재 포함 범위
|
||||||
@@ -20,7 +34,7 @@ cargo run --example toy_pipeline
|
|||||||
- graph/domain struct (`Atom`, `Bond`, `TorsionEdge`, `LigandGraph`)
|
- graph/domain struct (`Atom`, `Bond`, `TorsionEdge`, `LigandGraph`)
|
||||||
- one-sample entry API (`api::OneSampleToyPipeline`)
|
- one-sample entry API (`api::OneSampleToyPipeline`)
|
||||||
- graph/pose/simulator 최소 계약
|
- graph/pose/simulator 최소 계약
|
||||||
- Burn tensor shape 중심 adapter skeleton (`--features burn`)
|
- Burn tensor shape 중심 adapter skeleton (기본 활성)
|
||||||
- TODO 기반 구현 포인트 주석
|
- TODO 기반 구현 포인트 주석
|
||||||
|
|
||||||
### 이후 단계 권장
|
### 이후 단계 권장
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
#[cfg(feature = "burn")]
|
|
||||||
use std::{collections::BTreeMap, marker::PhantomData, path::Path};
|
use std::{collections::BTreeMap, marker::PhantomData, path::Path};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
|
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
|
||||||
use riemann_flow_gnn::{
|
use riemann_flow_gnn::{
|
||||||
api::{OneSampleToyBatch, OneSampleToyPipeline},
|
api::{OneSampleToyBatch, OneSampleToyPipeline},
|
||||||
@@ -12,18 +10,15 @@ use riemann_flow_gnn::{
|
|||||||
viz::{export_simulation_video, plot_static_png},
|
viz::{export_simulation_video, plot_static_png},
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
use riemann_flow_gnn::api::burn_training::{
|
use riemann_flow_gnn::api::burn_training::{
|
||||||
run_burn_training, BurnToyModel, BurnTrainConfig,
|
run_burn_training, BurnToyModel, BurnTrainConfig,
|
||||||
};
|
};
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
use safetensors::{
|
use safetensors::{
|
||||||
serialize_to_file,
|
serialize_to_file,
|
||||||
tensor::{Dtype, TensorView},
|
tensor::{Dtype, TensorView},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
type ExampleBackend = Autodiff<burn::backend::CudaJit>;
|
||||||
type ExampleBackend = Autodiff<burn::backend::NdArray>;
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ToyAtom {
|
struct ToyAtom {
|
||||||
@@ -100,12 +95,10 @@ fn load_toy_reference(path: &str) -> Result<(LigandGraph, LigandConformer), Box<
|
|||||||
Ok((graph, reference))
|
Ok((graph, reference))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
struct ExampleBurnModel<B: AutodiffBackend> {
|
struct ExampleBurnModel<B: AutodiffBackend> {
|
||||||
_backend: PhantomData<B>,
|
_backend: PhantomData<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
impl<B: AutodiffBackend> ExampleBurnModel<B> {
|
impl<B: AutodiffBackend> ExampleBurnModel<B> {
|
||||||
fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> {
|
fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> {
|
||||||
// TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현
|
// TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현
|
||||||
@@ -123,7 +116,6 @@ impl<B: AutodiffBackend> ExampleBurnModel<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
|
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str {
|
||||||
"burn-toy-model"
|
"burn-toy-model"
|
||||||
@@ -196,7 +188,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
});
|
});
|
||||||
let graph_features = pipeline.graph_contract();
|
let graph_features = pipeline.graph_contract();
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
let history = {
|
let history = {
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
|
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
|
||||||
@@ -233,16 +224,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
graph_features.spec.node_feat_dim,
|
graph_features.spec.node_feat_dim,
|
||||||
graph_features.spec.edge_feat_dim
|
graph_features.spec.edge_feat_dim
|
||||||
);
|
);
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
println!(
|
println!(
|
||||||
"burn training finished: epochs={}, last_loss={:.6}",
|
"burn training finished: epochs={}, last_loss={:.6}",
|
||||||
history.len(),
|
history.len(),
|
||||||
history.last().map(|m| m.loss).unwrap_or(0.0)
|
history.last().map(|m| m.loss).unwrap_or(0.0)
|
||||||
);
|
);
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
|
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
|
||||||
#[cfg(not(feature = "burn"))]
|
|
||||||
println!("burn feature is off: training skeleton skipped");
|
|
||||||
println!("saved video: outputs/simulation.mp4");
|
println!("saved video: outputs/simulation.mp4");
|
||||||
println!("target reference: data/toy/aspirin_reference.json");
|
println!("target reference: data/toy/aspirin_reference.json");
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
20
scripts/gpu_env_check.sh
Executable file
20
scripts/gpu_env_check.sh
Executable file
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env bash
# GPU environment check for the project.
#
# Runs four checks in order and aborts on the first failure (set -e):
#   1. nvidia-smi            - NVIDIA driver / device visibility
#   2. nvcc --version        - CUDA toolkit visibility
#   3. gpu_smoke_test.py     - PyTorch CUDA runtime (matmul on the GPU)
#   4. cargo check           - the toy_pipeline example builds against Burn
set -euo pipefail

# Run from the repository root (this script lives in scripts/) so the relative
# paths below resolve regardless of the directory the script is invoked from.
cd "$(dirname "${BASH_SOURCE[0]}")/.."

echo "[1/4] NVIDIA driver check"
nvidia-smi

echo
echo "[2/4] CUDA compiler check"
nvcc --version

echo
echo "[3/4] PyTorch CUDA runtime check"
# Some distributions ship only `python3`; accept either interpreter name.
if command -v python >/dev/null 2>&1; then
  python_bin=python
elif command -v python3 >/dev/null 2>&1; then
  python_bin=python3
else
  echo "error: neither 'python' nor 'python3' was found on PATH" >&2
  exit 1
fi
"${python_bin}" "scripts/gpu_smoke_test.py"

echo
# Burn is an unconditional dependency (no cargo feature gates it), so this is
# a plain build smoke check of the example rather than a feature check.
echo "[4/4] Burn build smoke check"
cargo check --example toy_pipeline

echo
echo "GPU environment check finished successfully."
|
||||||
23
scripts/gpu_smoke_test.py
Normal file
23
scripts/gpu_smoke_test.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import torch


def main() -> None:
    """Smoke-test the PyTorch CUDA runtime.

    Prints the torch version, CUDA availability and the name of every visible
    GPU, then runs a 2048x2048 matmul on the default CUDA device.

    Raises:
        SystemExit: if CUDA is not available, or if the matmul result contains
            non-finite values (the process then exits with a non-zero status).
    """
    cuda_ok = torch.cuda.is_available()
    print(f"torch={torch.__version__}")
    print(f"cuda_available={cuda_ok}")
    if not cuda_ok:
        raise SystemExit("CUDA is not available")

    device_count = torch.cuda.device_count()
    print(f"cuda_device_count={device_count}")
    for i in range(device_count):
        print(f"gpu[{i}]={torch.cuda.get_device_name(i)}")

    x = torch.randn((2048, 2048), device="cuda")
    y = torch.randn((2048, 2048), device="cuda")
    z = x @ y
    # Kernel launches are asynchronous; synchronizing makes any launch or
    # execution error surface here instead of at some later call.
    torch.cuda.synchronize()
    # A runtime that "succeeds" but returns NaN/Inf must not pass the check.
    if not bool(torch.isfinite(z).all()):
        raise SystemExit("CUDA matmul produced non-finite values")
    print(f"matmul_ok shape={tuple(z.shape)} dtype={z.dtype} device={z.device}")


if __name__ == "__main__":
    main()
|
||||||
@@ -46,7 +46,6 @@ impl OneSampleToyPipeline {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
pub mod burn_training {
|
pub mod burn_training {
|
||||||
use std::{
|
use std::{
|
||||||
collections::BTreeMap,
|
collections::BTreeMap,
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
#[cfg(feature = "burn")]
|
|
||||||
use burn::tensor::{backend::Backend, Int, Tensor};
|
use burn::tensor::{backend::Backend, Int, Tensor};
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
use crate::geometry::PoseState;
|
use crate::geometry::PoseState;
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
use crate::graph::GraphFeatures;
|
use crate::graph::GraphFeatures;
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
pub struct GraphTensors<B: Backend> {
|
pub struct GraphTensors<B: Backend> {
|
||||||
pub node_features: Tensor<B, 2>,
|
pub node_features: Tensor<B, 2>,
|
||||||
pub edge_index: Tensor<B, 2, Int>,
|
pub edge_index: Tensor<B, 2, Int>,
|
||||||
@@ -14,7 +10,6 @@ pub struct GraphTensors<B: Backend> {
|
|||||||
pub torsion_edge_index: Tensor<B, 2, Int>,
|
pub torsion_edge_index: Tensor<B, 2, Int>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
impl<B: Backend> GraphTensors<B> {
|
impl<B: Backend> GraphTensors<B> {
|
||||||
pub fn from_features(features: &GraphFeatures, device: &B::Device) -> Self {
|
pub fn from_features(features: &GraphFeatures, device: &B::Device) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -35,14 +30,12 @@ impl<B: Backend> GraphTensors<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
pub struct PoseTensors<B: Backend> {
|
pub struct PoseTensors<B: Backend> {
|
||||||
pub translation: Tensor<B, 2>,
|
pub translation: Tensor<B, 2>,
|
||||||
pub rotation: Tensor<B, 2>,
|
pub rotation: Tensor<B, 2>,
|
||||||
pub torsions: Tensor<B, 2>,
|
pub torsions: Tensor<B, 2>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "burn")]
|
|
||||||
impl<B: Backend> PoseTensors<B> {
|
impl<B: Backend> PoseTensors<B> {
|
||||||
pub fn from_pose_batch(poses: &[PoseState], device: &B::Device) -> Self {
|
pub fn from_pose_batch(poses: &[PoseState], device: &B::Device) -> Self {
|
||||||
let batch = poses.len();
|
let batch = poses.len();
|
||||||
|
|||||||
Reference in New Issue
Block a user