Add initial implementation of Riemannian Flow GNN with essential data structures, simulation logic, and visualization utilities

This commit is contained in:
demian3b
2026-04-15 23:19:53 +09:00
parent b1b1ae6229
commit 9e45a090af
15 changed files with 7000 additions and 1 deletions

5022
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

19
Cargo.toml Normal file
View File

@@ -0,0 +1,19 @@
[package]
name = "riemann-flow-gnn"
version = "0.1.0"
edition = "2021"
[features]
default = []
burn = ["dep:burn", "dep:safetensors"]
[dependencies]
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
burn = { version = "0.14", optional = true, features = ["autodiff", "ndarray"] }
safetensors = { version = "0.4", optional = true }
[dev-dependencies]
approx = "0.5"

View File

@@ -1,3 +1,33 @@
# riemann-flow-gnn
Riemannian Flow matching with GNN implementation in rust/burn
## Toy prototype skeleton (API-first)
이 저장소는 `R^3 x SO(3) x T^m` one-sample overfitting 실험을 위한 **API 계약 중심 스켈레톤**을 포함합니다.
내부 수학/시뮬레이션 로직은 의도적으로 최소화되어 있으며, TODO 지점부터 실험 코드를 채우는 구조입니다.
### 실행
```bash
cargo run --example toy_pipeline
```
예제는 파이프라인 계약 확인용으로 동작합니다.
### 현재 포함 범위
- graph/domain struct (`Atom`, `Bond`, `TorsionEdge`, `LigandGraph`)
- one-sample entry API (`api::OneSampleToyPipeline`)
- graph/pose/simulator 최소 계약
- Burn tensor shape 중심 adapter skeleton (`--features burn`)
- TODO 기반 구현 포인트 주석
### 이후 단계 권장
1. random init sampling 테스트
2. translation/rotation/torsion 단위 테스트 확장
3. GNN 모델 정의 및 입력 feature 고정
4. geodesic path/velocity supervision 연결
5. 학습 루프/실행 스크립트 추가
6. epoch별 trajectory 저장 자동화

View File

@@ -0,0 +1,697 @@
# Burn 기반 Pose Manifold Graph 구현 AI 전달 문서
## 0. 문서 목적
이 문서는 **단일 ligand pose overfitting 실험**을 위해, AI 에이전트가 구현해야 할 범위를 명확히 정의한 handoff 문서이다.
구현 대상은 다음 세 가지다.
1. **Burn 규격에 맞는 manifold tensor 표현 및 graph 구현체**
2. **graph 입력을 받아 시각화 가능한 구조 표현과 렌더링 파이프라인**
3. **time parameter `t ∈ [0, 1]`에 따라 translation / rotation / torsion이 변하는 simulator**
이번 단계의 목표는 **연구용 완성품**이 아니라, 발표 및 초기 검증을 위한 **작동 가능한 prototype**이다.
---
## 1. 문제 설정
### 1.1 상태 공간
Ligand pose state는 아래 product space로 둔다.
- `translation`: `R^3`
- `rotation`: `SO(3)`
- `torsion`: `T^m` where `m = number of rotatable bonds`
즉 하나의 pose state는 개념적으로 아래와 같다.
```text
x = (p, R, theta)
p : global translation, shape [3]
R : global rotation
theta : torsion angles for rotatable bonds, shape [m]
```
### 1.2 그래프 입력
입력 ligand는 graph로 표현한다.
- node = atom
- edge = bond
- torsion-available edge = 회전 가능한 bond
이번 단계에서는 **protein / pocket context는 제외**한다.
오직 ligand 단일 graph와 target pose만 다룬다.
### 1.3 실험 목적
- 주어진 ligand graph 하나에 대해
- 임의의 초기 pose에서 출발해
- 하나의 fixed target pose로 수렴하는 trajectory를 생성하고
- 이를 시각화 가능하게 만든다.
학습 로직 전체를 이 문서에서 구현하라는 뜻은 아니다.
이번 구현의 중심은 **data representation + simulator + visualization**이다.
---
## 2. 구현 범위 요약
AI 에이전트는 아래 항목을 구현한다.
### 필수 구현
- Burn-friendly tensor container
- ligand graph data structure
- torsion metadata extraction/representation
- pose state container
- explicit simulator for `t=0..1`
- 3D/2D visualization export utility
### 이번 단계에서 제외
- full training loop
- optimizer / scheduler
- dataset loader for many molecules
- protein-ligand interaction scoring
- equivariant GNN model 자체
---
## 3. 설계 원칙
### 3.1 Burn 친화적 설계
구현체는 Burn 모델과 쉽게 연결될 수 있어야 한다.
따라서 다음 원칙을 지킨다.
- **tensor payload는 Burn tensor로 쉽게 변환 가능**해야 한다.
- 구조체는 geometry / graph / rendering 책임을 분리한다.
- manifold-specific logic는 순수 수학 유틸로 분리한다.
- 시각화는 모델과 독립적인 모듈로 둔다.
### 3.2 API 변동에 덜 민감한 설계
Burn 버전 변화에 덜 흔들리도록, 내부 표현과 backend 종속성을 너무 깊게 섞지 않는다.
권장 방식:
- domain object는 plain Rust struct로 유지
- 필요 시 `to_tensor::<B: Backend>(&device)` 같은 adapter 제공
- 학습용 Autodiff backend와 추론용 backend를 쉽게 교체할 수 있게 설계
### 3.3 발표용 prototype 우선
이번 구현은 성능 최적화보다 아래가 중요하다.
- 구조가 명확할 것
- simulator가 deterministic하게 잘 동작할 것
- geometry가 눈으로 드러날 것
---
## 4. 구현해야 할 핵심 데이터 구조
### 4.1 Atom / Bond / LigandGraph
#### Atom
최소 필드:
```rust
pub struct Atom {
pub index: usize,
pub atomic_number: u8,
pub formal_charge: i8,
pub is_aromatic: bool,
}
```
#### Bond
최소 필드:
```rust
pub struct Bond {
pub index: usize,
pub src: usize,
pub dst: usize,
pub bond_order: u8,
pub is_aromatic: bool,
pub is_rotatable: bool,
}
```
#### TorsionEdge
회전 가능한 bond에 대해, 어떤 원자 집합이 어느 쪽 fragment로 회전하는지 알아야 한다.
이 metadata가 simulator에서 중요하다.
```rust
pub struct TorsionEdge {
pub bond_index: usize,
pub atom_left: usize,
pub atom_right: usize,
pub rotating_side: Vec<usize>,
}
```
설명:
- `atom_left`, `atom_right`는 torsion axis를 이루는 bond의 두 원자
- `rotating_side`는 회전 적용 대상 atom index 목록
#### LigandGraph
```rust
pub struct LigandGraph {
pub atoms: Vec<Atom>,
pub bonds: Vec<Bond>,
pub edge_index: Vec<(usize, usize)>,
pub torsion_edges: Vec<TorsionEdge>,
}
```
요구사항:
- undirected bond graph를 기본으로 하되
- GNN 입력용으로는 directed edge pair로 변환 가능해야 함
- torsion edge list를 별도로 유지해야 함
---
### 4.2 PoseState
pose state는 ligand의 geometry 상태를 담는다.
#### 권장 표현
```rust
pub struct PoseState {
pub translation: [f32; 3],
pub rotation_quat_xyzw: [f32; 4],
pub torsions: Vec<f32>,
}
```
주의:
- quaternion은 normalize 보장 필요
- torsion은 radian 단위
- torsion 범위는 기본적으로 `[-pi, pi)` 로 유지
rotation은 내부적으로 quaternion으로 저장하되,
필요 시 axis-angle / rotation matrix 변환 유틸을 제공한다.
---
### 4.3 LigandConformer
실제 원자 좌표를 가진 구조 표현.
```rust
pub struct LigandConformer {
pub graph: LigandGraph,
pub coords: Vec<[f32; 3]>,
}
```
요구사항:
- `coords.len() == graph.atoms.len()` 이어야 함
- reference conformer를 기준으로 pose transform 적용 가능해야 함
---
## 5. Burn에 연결 가능한 tensor adapter
AI 에이전트는 아래와 같은 방향의 adapter를 제공한다.
```rust
pub struct GraphTensors<B: burn::tensor::backend::Backend> {
pub node_features: burn::tensor::Tensor<B, 2>,
pub edge_index: burn::tensor::Tensor<B, 2, burn::tensor::Int>,
pub edge_features: Option<burn::tensor::Tensor<B, 2>>,
pub torsion_edge_index: burn::tensor::Tensor<B, 2, burn::tensor::Int>,
}
```
목표는 정확한 최종 타입을 박제하는 것이 아니라, 아래 contract를 만족하는 것이다.
### contract
- node features: `[num_nodes, node_feat_dim]`
- edge index: `[2, num_edges]`
- torsion edge index: `[2, num_rotatable_bonds]` 또는 equivalent representation
- pose tensor export: translation / quaternion / torsion을 batchable tensor로 변환 가능
### node feature 최소 구성
최초 버전은 아래 정도면 충분하다.
- atomic number
- formal charge
- aromatic flag
- degree(optional)
### edge feature 최소 구성
- bond order
- aromatic flag
- rotatable flag
### pose tensor export
```rust
pub struct PoseTensors<B: burn::tensor::backend::Backend> {
pub translation: burn::tensor::Tensor<B, 2>, // [batch, 3]
pub rotation: burn::tensor::Tensor<B, 2>, // [batch, 4] quaternion
pub torsions: burn::tensor::Tensor<B, 2>, // [batch, m]
}
```
---
## 6. 구현해야 할 manifold math 유틸
이 모듈은 모델과 완전히 분리된 순수 수학 유틸이어야 한다.
파일 예시:
```text
src/geometry/
euclidean.rs
so3.rs
torus.rs
pose.rs
```
### 6.1 Translation (`R^3`)
필수 함수:
- `lerp_translation(p0, p1, t)`
- `translation_delta(p0, p1)`
### 6.2 Rotation (`SO(3)`)
필수 함수:
- quaternion normalize
- quaternion multiply / inverse
- quaternion to rotation matrix
- rotation matrix to quaternion (optional)
- `slerp(q0, q1, t)`
- `relative_rotation(q_from, q_to)`
- `so3_log(q_rel) -> [f32; 3]`
- `so3_exp(omega) -> quat`
설명:
- simulator에서는 `slerp` 기반 interpolation만 먼저 구현해도 충분
- 이후 tangent vector supervision을 위해 `log/exp` utility 유지 권장
### 6.3 Torsion (`T^m`)
필수 함수:
- `wrap_pi(x)`
- `wrap_vec(theta)`
- `shortest_angle_delta(a, b)`
- `interpolate_torsion(theta0, theta1, t)`
정의:
```text
shortest_angle_delta(a, b) = wrap_pi(b - a)
interpolate_torsion(theta0, theta1, t) = wrap(theta0 + t * shortest_delta)
```
### 6.4 Product Pose Geometry
필수 함수:
- `interpolate_pose(pose0, pose1, t) -> PoseState`
- `pose_delta(pose0, pose1) -> PoseDelta`
예상 구조:
```rust
pub struct PoseDelta {
pub translation: [f32; 3],
pub rotation_tangent: [f32; 3],
pub torsions: Vec<f32>,
}
```
---
## 7. 좌표 복원 / simulator 요구사항
핵심 요구는 다음과 같다.
> `reference conformer + pose state` 를 받아, 시각화 가능한 현재 좌표를 만든다.
### 7.1 적용 순서
권장 적용 순서:
1. reference conformer 좌표 복사
2. rotatable bond별 torsion 회전 적용
3. global rotation 적용
4. global translation 적용
### 7.2 torsion 회전 적용
`TorsionEdge`에 대해:
- axis = bond direction
- pivot = bond axis 상의 원자 위치
- rotating_side에 포함된 atom들만 회전
- angle = 해당 torsion value
즉, torsion은 local internal transform이다.
### 7.3 global transform 적용
- rotation quaternion으로 전체 원자 좌표 회전
- translation vector 더하기
### 7.4 simulator
입력:
- `reference conformer`
- `pose_start`
- `pose_target`
- `t in [0,1]`
출력:
- `PoseState at t`
- `LigandConformer at t`
즉,
```rust
pub fn simulate_at_t(
reference: &LigandConformer,
pose_start: &PoseState,
pose_target: &PoseState,
t: f32,
) -> LigandConformer
```
형태의 API가 있어야 한다.
### 7.5 trajectory
추가로 아래 helper도 구현한다.
```rust
pub fn simulate_trajectory(
reference: &LigandConformer,
pose_start: &PoseState,
pose_target: &PoseState,
num_steps: usize,
) -> Vec<LigandConformer>
```
- `t = 0..1` 균등 분할
- endpoint 포함
- 결과는 animation 또는 프레임 저장에 활용 가능
---
## 8. Visualization 요구사항
이번 단계에서 visualization은 매우 중요하다.
발표용 데모에 직접 쓰일 수 있어야 한다.
### 필수 출력
- ligand graph connectivity 확인용 2D/3D static view
- current pose 단일 프레임 렌더링
- trajectory animation용 frame sequence export
### 최소 기능
#### A. static plot
입력:
- atom coordinates
- bond list
출력:
- atoms = points
- bonds = lines
- rotatable bond는 다른 색상 또는 강조
#### B. trajectory frames
입력:
- `Vec<LigandConformer>`
출력:
- png sequence 또는 gif용 intermediate frames
#### C. optional: HTML viewer
가능하면 아래 중 하나 추가 검토:
- plotly 기반 HTML
- three.js/json export용 포맷
- simple `.xyz` / `.sdf` frame export
### 시각화 우선순위
1. 정적 PNG 출력
2. trajectory frame sequence 출력
3. 선택적으로 gif/mp4/HTML
주의:
- visualization 구현은 model과 완전히 분리
- geometry correctness 검증이 목적이지 예쁜 UI가 목적은 아님
---
## 9. 권장 디렉터리 구조
```text
src/
graph/
atom.rs
bond.rs
torsion.rs
ligand_graph.rs
featurize.rs
geometry/
euclidean.rs
so3.rs
torus.rs
pose.rs
conformer/
conformer.rs
apply_torsion.rs
apply_pose.rs
burn_adapter/
graph_tensors.rs
pose_tensors.rs
simulator/
interpolate.rs
trajectory.rs
viz/
plot_static.rs
plot_trajectory.rs
export.rs
lib.rs
```
---
## 10. 구현 상세 계약
### 10.1 반드시 보장해야 할 것
- quaternion normalization이 항상 유지될 것
- torsion angle은 항상 wrapped range에 있을 것
- `rotating_side` 계산이 일관될 것
- bond axis 기준 회전이 올바를 것
- `simulate_at_t(..., 0.0)` 는 시작 pose와 일치할 것
- `simulate_at_t(..., 1.0)` 는 타깃 pose와 일치할 것
### 10.2 numerical sanity checks
AI 에이전트는 아래 테스트를 함께 작성한다.
#### Test 1: wrap correctness
- `pi - eps``-pi + eps` 의 차이가 작은 값으로 처리되는지 검증
#### Test 2: quaternion norm
- interpolation 전후 quaternion norm ≈ 1
#### Test 3: endpoint correctness
- trajectory 첫 프레임과 마지막 프레임이 입력 pose와 일치
#### Test 4: rigid transform invariance
- torsion 없는 molecule에서는 internal geometry가 rigid하게 유지
#### Test 5: torsion rotation locality
- 특정 torsion 회전 시 rotating_side atom만 변하는지 확인
---
## 11. 입력 데이터 가정
이번 구현은 범용 loader까지 강제하지 않는다.
아래 두 입력 경로 중 하나만 우선 지원하면 된다.
### 경로 A: 수동 구성
- atom list
- bond list
- reference coordinates
- rotatable bond metadata
- target pose
### 경로 B: RDKit 등 외부 전처리 결과 import
- JSON 또는 serde-friendly 포맷으로 graph + conformer metadata load
권장:
- Rust 내부 구현은 외부 화학 toolkit에 강하게 묶지 말 것
- 외부 전처리 결과를 import하는 구조가 더 안전함
---
## 12. 구현 우선순위
### Phase 1
- `Atom`, `Bond`, `TorsionEdge`, `LigandGraph`
- `PoseState`
- `wrap_pi`, `shortest_angle_delta`
- quaternion normalize + slerp
- `interpolate_pose`
### Phase 2
- `LigandConformer`
- torsion application
- global rotation/translation application
- `simulate_at_t`
### Phase 3
- Burn tensor adapter
- graph/node/edge feature export
- pose tensor export
### Phase 4
- visualization static export
- trajectory frame export
- tests / sanity checks
---
## 13. 기대 산출물
AI 에이전트는 최종적으로 아래를 제공해야 한다.
### 코드 산출물
- Rust crate source
- 기본 unit tests
- example 실행 코드
### example 요구
example은 아래를 보여줘야 한다.
1. ligand graph 생성 또는 로드
2. reference conformer 준비
3. start pose / target pose 설정
4. `t=0.0, 0.25, 0.5, 0.75, 1.0` 프레임 생성
5. static image 또는 trajectory frame export
6. graph tensor / pose tensor 변환 예시
---
## 14. AI 에이전트에게 주는 구현 지침
### 반드시 지킬 것
- 너무 많은 abstraction을 한 번에 넣지 말 것
- 우선 **작동하는 명시적 구현**을 만들 것
- geometry correctness를 최우선으로 둘 것
- Burn 모델 학습 코드까지 과도하게 확장하지 말 것
### 선호하는 구현 스타일
- plain Rust struct 중심
- 명시적 타입
- unit test 풍부하게
- 실패 시 panic보다 `Result` 기반 오류 처리 선호
- visualization/output path는 예제에서 바로 실행 가능하게 구성
### 금지 사항
- 불필요한 거대 프레임워크 도입
- protein context까지 무리하게 확장
- 학습 코드와 시각화 코드를 강하게 결합
- 화학 toolkit 종속성을 core domain object 안에 깊게 삽입
---
## 15. 구현 완료 판정 기준
아래 조건을 만족하면 이번 단계 구현 완료로 본다.
- ligand graph와 torsion metadata를 표현할 수 있음
- pose state를 translation / rotation / torsion으로 표현할 수 있음
- `t=0..1` explicit interpolation simulator가 동작함
- conformer 좌표를 올바르게 복원함
- static frame 및 trajectory frame을 export할 수 있음
- Burn 입력용 tensor adapter가 존재함
- 최소한의 numerical sanity tests가 모두 통과함
---
## 16. 선택적 확장 아이디어
이번 단계 이후 확장 후보:
- batch support
- pocket context 추가
- learned vector field model 연결
- explicit path supervision dataset 생성기
- SE(3)-equivariant / graph neural model 연결
- target pose 단일 개체가 아니라 multiple conformer target 지원
---
## 17. 마지막 메모
이 구현은 최종 모델이 아니라, 아래 질문에 답하기 위한 prototype이다.
> “ligand pose manifold (`R^3 × SO(3) × T^m`) 위의 상태를 Rust/Burn 친화적으로 표현하고,
> graph 입력과 연결하며,
> explicit path simulator와 visualization까지 한 번에 돌릴 수 있는가?”
이 질문에 yes를 만들 수 있으면 이번 단계는 충분히 성공이다.

View File

@@ -0,0 +1,343 @@
{
"smiles": "CC(=O)OC1=CC=CC=C1C(O)=O",
"generator": {
"tool": "rdkit",
"rdkit_version": "2025.09.4",
"method": "ETKDGv3 + UFFOptimizeMolecule",
"seed": 20260415
},
"atoms": [
{
"index": 0,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": false
},
{
"index": 1,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": false
},
{
"index": 2,
"atomic_number": 8,
"formal_charge": 0,
"is_aromatic": false
},
{
"index": 3,
"atomic_number": 8,
"formal_charge": 0,
"is_aromatic": false
},
{
"index": 4,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": true
},
{
"index": 5,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": true
},
{
"index": 6,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": true
},
{
"index": 7,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": true
},
{
"index": 8,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": true
},
{
"index": 9,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": true
},
{
"index": 10,
"atomic_number": 6,
"formal_charge": 0,
"is_aromatic": false
},
{
"index": 11,
"atomic_number": 8,
"formal_charge": 0,
"is_aromatic": false
},
{
"index": 12,
"atomic_number": 8,
"formal_charge": 0,
"is_aromatic": false
}
],
"bonds": [
{
"index": 0,
"src": 0,
"dst": 1,
"bond_order": 1,
"is_aromatic": false,
"is_rotatable": false
},
{
"index": 1,
"src": 1,
"dst": 2,
"bond_order": 2,
"is_aromatic": false,
"is_rotatable": false
},
{
"index": 2,
"src": 1,
"dst": 3,
"bond_order": 1,
"is_aromatic": false,
"is_rotatable": true
},
{
"index": 3,
"src": 3,
"dst": 4,
"bond_order": 1,
"is_aromatic": false,
"is_rotatable": true
},
{
"index": 4,
"src": 4,
"dst": 5,
"bond_order": 2,
"is_aromatic": true,
"is_rotatable": false
},
{
"index": 5,
"src": 5,
"dst": 6,
"bond_order": 2,
"is_aromatic": true,
"is_rotatable": false
},
{
"index": 6,
"src": 6,
"dst": 7,
"bond_order": 2,
"is_aromatic": true,
"is_rotatable": false
},
{
"index": 7,
"src": 7,
"dst": 8,
"bond_order": 2,
"is_aromatic": true,
"is_rotatable": false
},
{
"index": 8,
"src": 8,
"dst": 9,
"bond_order": 2,
"is_aromatic": true,
"is_rotatable": false
},
{
"index": 9,
"src": 9,
"dst": 10,
"bond_order": 1,
"is_aromatic": false,
"is_rotatable": true
},
{
"index": 10,
"src": 10,
"dst": 11,
"bond_order": 1,
"is_aromatic": false,
"is_rotatable": false
},
{
"index": 11,
"src": 10,
"dst": 12,
"bond_order": 2,
"is_aromatic": false,
"is_rotatable": false
},
{
"index": 12,
"src": 9,
"dst": 4,
"bond_order": 2,
"is_aromatic": true,
"is_rotatable": false
}
],
"coords_angstrom": [
[
-2.445027592718335,
-1.67948133887539,
-1.2311204827416857
],
[
-2.3256540791911036,
-0.33513666420602944,
-0.5982962602008224
],
[
-3.2501478181982857,
0.5081276650303239,
-0.7531554024479279
],
[
-1.240796578897732,
-0.07117406663982812,
0.24766934621062192
],
[
-0.8693310896814674,
1.1975591923239997,
0.7212683477007807
],
[
-1.0718661569800187,
2.348408762854174,
-0.0659265674971829
],
[
-0.6879381383599061,
3.605496329692259,
0.40104039015513443
],
[
-0.08809733429840264,
3.732148598191945,
1.6516331751807383
],
[
0.1383074740644653,
2.6011946492493183,
2.4373021116517832
],
[
-0.2413518197061185,
1.3212637178820958,
1.9855525634937683
],
[
0.016171886685239042,
0.13903726167392894,
2.8502637354849125
],
[
-0.38040222522103195,
-1.1365956849618335,
2.4579933961830043
],
[
0.6001216008188676,
0.2801422824805318,
3.9591263627209816
]
],
"bond_lengths_angstrom": [
{
"bond_index": 0,
"src": 0,
"dst": 1,
"length_angstrom": 1.4906304493999016
},
{
"bond_index": 1,
"src": 1,
"dst": 2,
"length_angstrom": 1.26085873767184
},
{
"bond_index": 2,
"src": 1,
"dst": 3,
"length_angstrom": 1.4008032895762923
},
{
"bond_index": 3,
"src": 3,
"dst": 4,
"length_angstrom": 1.4042673200969
},
{
"bond_index": 4,
"src": 4,
"dst": 5,
"length_angstrom": 1.4089538750394368
},
{
"bond_index": 5,
"src": 5,
"dst": 6,
"length_angstrom": 1.3948935492429244
},
{
"bond_index": 6,
"src": 6,
"dst": 7,
"length_angstrom": 1.3927785542904383
},
{
"bond_index": 7,
"src": 7,
"dst": 8,
"length_angstrom": 1.3955614101909997
},
{
"bond_index": 8,
"src": 8,
"dst": 9,
"length_angstrom": 1.4094119421510216
},
{
"bond_index": 9,
"src": 9,
"dst": 10,
"length_angstrom": 1.4871796340988404
},
{
"bond_index": 10,
"src": 10,
"dst": 11,
"length_angstrom": 1.3922594800801036
},
{
"bond_index": 11,
"src": 10,
"dst": 12,
"length_angstrom": 1.2611440130351712
},
{
"bond_index": 12,
"src": 9,
"dst": 4,
"length_angstrom": 1.417065754107979
}
],
"torsion_edges": []
}

230
examples/toy_pipeline.rs Normal file
View File

@@ -0,0 +1,230 @@
#[cfg(feature = "burn")]
use std::{collections::BTreeMap, marker::PhantomData, path::Path};
use std::fs;
#[cfg(feature = "burn")]
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
use riemann_flow_gnn::{
api::{OneSampleToyBatch, OneSampleToyPipeline},
conformer::LigandConformer,
geometry::PoseState,
graph::{Atom, Bond, LigandGraph},
viz::{export_simulation_video, plot_static_png},
};
use serde::Deserialize;
#[cfg(feature = "burn")]
use riemann_flow_gnn::api::burn_training::{
run_burn_training, BurnToyModel, BurnTrainConfig,
};
#[cfg(feature = "burn")]
use safetensors::{
serialize_to_file,
tensor::{Dtype, TensorView},
};
#[cfg(feature = "burn")]
type ExampleBackend = Autodiff<burn::backend::NdArray>;
#[derive(Debug, Deserialize)]
struct ToyAtom {
index: usize,
atomic_number: u8,
formal_charge: i8,
is_aromatic: bool,
}
#[derive(Debug, Deserialize)]
struct ToyBond {
index: usize,
src: usize,
dst: usize,
bond_order: u8,
is_aromatic: bool,
is_rotatable: bool,
}
#[derive(Debug, Deserialize)]
struct ToyReferenceData {
atoms: Vec<ToyAtom>,
bonds: Vec<ToyBond>,
coords_angstrom: Vec<[f32; 3]>,
}
fn load_toy_reference(path: &str) -> Result<(LigandGraph, LigandConformer), Box<dyn std::error::Error>> {
let raw = fs::read_to_string(path)?;
let parsed: ToyReferenceData = serde_json::from_str(&raw)?;
let atoms = parsed
.atoms
.into_iter()
.map(|a| Atom {
index: a.index,
atomic_number: a.atomic_number,
formal_charge: a.formal_charge,
is_aromatic: a.is_aromatic,
})
.collect::<Vec<_>>();
let bonds = parsed
.bonds
.into_iter()
.map(|b| Bond {
index: b.index,
src: b.src,
dst: b.dst,
bond_order: b.bond_order,
is_aromatic: b.is_aromatic,
is_rotatable: b.is_rotatable,
})
.collect::<Vec<_>>();
let graph = LigandGraph::new(atoms, bonds, vec![]);
let reference = LigandConformer::new(graph.clone(), parsed.coords_angstrom)?;
Ok((graph, reference))
}
#[cfg(feature = "burn")]
struct ExampleBurnModel<B: AutodiffBackend> {
_backend: PhantomData<B>,
}
#[cfg(feature = "burn")]
impl<B: AutodiffBackend> ExampleBurnModel<B> {
fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> {
// TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현
Ok(())
}
fn loss_template(&self) -> Result<f64, String> {
// TODO(user): translation/rotation/torsion loss 조합 구현
Ok(1.0)
}
fn optimizer_step_template(&mut self, _lr: f64) -> Result<(), String> {
// TODO(user): Burn optimizer step 구현
Ok(())
}
}
#[cfg(feature = "burn")]
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
fn name(&self) -> &'static str {
"burn-toy-model"
}
fn new(_device: &B::Device) -> Self {
Self {
_backend: PhantomData,
}
}
fn initialize(
&mut self,
_graph: &riemann_flow_gnn::graph::GraphFeatures,
_device: &B::Device,
) -> Result<(), String> {
// TODO(user): Burn 모델 파라미터 초기화
Ok(())
}
fn train_step(
&mut self,
graph: &riemann_flow_gnn::graph::GraphFeatures,
epoch: usize,
lr: f64,
) -> Result<f64, String> {
self.forward_template(graph)?;
let base_loss = self.loss_template()?;
self.optimizer_step_template(lr)?;
Ok(base_loss / (epoch as f64))
}
fn save_model_safetensors(&self, path: &Path) -> Result<(), String> {
// TODO(user): 실제 model state tensor들을 safetensors로 저장
let value_bytes = 1.0f32.to_le_bytes();
let value_view =
TensorView::new(Dtype::F32, vec![1], &value_bytes).map_err(|e| e.to_string())?;
let tensors = BTreeMap::from([("placeholder_weight".to_string(), value_view)]);
serialize_to_file(tensors, &None, path).map_err(|e| e.to_string())?;
Ok(())
}
fn load_model_safetensors(_device: &B::Device, _path: &Path) -> Result<Self, String> {
// TODO(user): safetensors에서 model state 복원
Ok(Self {
_backend: PhantomData,
})
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (graph, reference) = load_toy_reference("data/toy/aspirin_reference.json")?;
let pose_start = PoseState {
translation: [0.0, 0.0, 0.0],
rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0],
torsions: vec![0.0],
};
let pose_target = PoseState {
translation: [1.0, 1.5, 0.5],
rotation_quat_xyzw: [0.0, 0.0, 0.38268343, 0.9238795],
torsions: vec![1.2],
};
let pipeline = OneSampleToyPipeline::new(OneSampleToyBatch {
graph,
reference: reference.clone(),
pose_start,
pose_target,
});
let graph_features = pipeline.graph_contract();
#[cfg(feature = "burn")]
let history = {
let device = Default::default();
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
let train_config = BurnTrainConfig {
epochs: 20,
learning_rate: 1e-3,
checkpoint_every: 5,
checkpoint_dir: "outputs/checkpoints".to_string(),
};
run_burn_training::<ExampleBackend, _>(
&mut model,
&graph_features,
&device,
&train_config,
)?
};
// 3) simulation video export
let trajectory = export_simulation_video(
&pipeline.batch.reference,
&pipeline.batch.pose_start,
&pipeline.batch.pose_target,
48,
"outputs",
24,
)?;
plot_static_png(&trajectory[0], "outputs/static_start.png")?;
plot_static_png(trajectory.last().ok_or("empty trajectory")?, "outputs/static_end.png")?;
println!(
"contract: nodes={}, edges={}, torsions={}, node_dim={}, edge_dim={}",
graph_features.num_nodes,
graph_features.num_edges,
graph_features.num_torsion_edges,
graph_features.spec.node_feat_dim,
graph_features.spec.edge_feat_dim
);
#[cfg(feature = "burn")]
println!(
"burn training finished: epochs={}, last_loss={:.6}",
history.len(),
history.last().map(|m| m.loss).unwrap_or(0.0)
);
#[cfg(feature = "burn")]
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
#[cfg(not(feature = "burn"))]
println!("burn feature is off: training skeleton skipped");
println!("saved video: outputs/simulation.mp4");
println!("target reference: data/toy/aspirin_reference.json");
Ok(())
}

194
src/api.rs Normal file
View File

@@ -0,0 +1,194 @@
use crate::{
conformer::{ConformerError, LigandConformer},
geometry::PoseState,
graph::{GraphFeatures, LigandGraph},
simulator::{simulate_at_t, simulate_trajectory},
};
#[derive(Debug, Clone)]
pub struct OneSampleToyBatch {
pub graph: LigandGraph,
pub reference: LigandConformer,
pub pose_start: PoseState,
pub pose_target: PoseState,
}
#[derive(Debug, Clone)]
pub struct OneSampleToyPipeline {
pub batch: OneSampleToyBatch,
}
impl OneSampleToyPipeline {
pub fn new(batch: OneSampleToyBatch) -> Self {
Self { batch }
}
pub fn graph_contract(&self) -> GraphFeatures {
GraphFeatures::from_graph(&self.batch.graph)
}
pub fn sample_at_t(&self, t: f32) -> Result<LigandConformer, ConformerError> {
simulate_at_t(
&self.batch.reference,
&self.batch.pose_start,
&self.batch.pose_target,
t,
)
}
pub fn trajectory(&self, num_steps: usize) -> Result<Vec<LigandConformer>, ConformerError> {
simulate_trajectory(
&self.batch.reference,
&self.batch.pose_start,
&self.batch.pose_target,
num_steps,
)
}
}
#[cfg(feature = "burn")]
pub mod burn_training {
use std::{
collections::BTreeMap,
fs,
path::{Path, PathBuf},
};
use burn::tensor::backend::AutodiffBackend;
use safetensors::{
serialize_to_file,
tensor::{Dtype, TensorView},
};
use thiserror::Error;
use crate::{burn_adapter::GraphTensors, graph::GraphFeatures};
#[derive(Debug, Clone)]
pub struct BurnTrainConfig {
pub epochs: usize,
pub learning_rate: f64,
pub checkpoint_every: usize,
pub checkpoint_dir: String,
}
impl Default for BurnTrainConfig {
fn default() -> Self {
Self {
epochs: 100,
learning_rate: 1e-3,
checkpoint_every: 10,
checkpoint_dir: "outputs/checkpoints".to_string(),
}
}
}
#[derive(Debug, Clone)]
pub struct BurnTrainMetrics {
pub epoch: usize,
pub loss: f64,
}
#[derive(Debug, Error)]
pub enum BurnTrainingError {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("safetensors error: {0}")]
SafeTensors(#[from] safetensors::SafeTensorError),
#[error("model error: {0}")]
Model(String),
}
pub trait BurnToyModel<B: AutodiffBackend>: Sized {
fn name(&self) -> &'static str;
fn new(device: &B::Device) -> Self;
fn initialize(&mut self, _graph: &GraphFeatures, _device: &B::Device) -> Result<(), String>;
fn train_step(&mut self, _graph: &GraphFeatures, epoch: usize, _lr: f64) -> Result<f64, String>;
fn save_model_safetensors(&self, path: &Path) -> Result<(), String>;
fn load_model_safetensors(device: &B::Device, path: &Path) -> Result<Self, String>;
}
#[derive(Debug, Clone)]
pub struct BurnCheckpointPaths {
pub model_path: PathBuf,
pub trainer_state_path: PathBuf,
}
pub fn graph_contract_to_tensors<B: AutodiffBackend>(
graph: &GraphFeatures,
device: &B::Device,
) -> GraphTensors<B> {
GraphTensors::from_features(graph, device)
}
pub fn save_trainer_state_safetensors(
path: &Path,
epoch: usize,
loss: f64,
learning_rate: f64,
) -> Result<(), BurnTrainingError> {
let epoch_bytes = (epoch as i64).to_le_bytes();
let loss_bytes = loss.to_le_bytes();
let lr_bytes = learning_rate.to_le_bytes();
let epoch_view = TensorView::new(Dtype::I64, vec![1], &epoch_bytes)?;
let loss_view = TensorView::new(Dtype::F64, vec![1], &loss_bytes)?;
let lr_view = TensorView::new(Dtype::F64, vec![1], &lr_bytes)?;
let tensors = BTreeMap::from([
("epoch".to_string(), epoch_view),
("loss".to_string(), loss_view),
("learning_rate".to_string(), lr_view),
]);
serialize_to_file(tensors, &None, path)?;
Ok(())
}
pub fn save_checkpoint_safetensors<B: AutodiffBackend, M: BurnToyModel<B>>(
model: &M,
config: &BurnTrainConfig,
epoch: usize,
loss: f64,
) -> Result<BurnCheckpointPaths, BurnTrainingError> {
fs::create_dir_all(&config.checkpoint_dir)?;
let model_path = Path::new(&config.checkpoint_dir)
.join(format!("{}_epoch_{:04}.model.safetensors", model.name(), epoch));
let trainer_state_path = Path::new(&config.checkpoint_dir)
.join(format!("{}_epoch_{:04}.trainer.safetensors", model.name(), epoch));
model
.save_model_safetensors(&model_path)
.map_err(BurnTrainingError::Model)?;
save_trainer_state_safetensors(&trainer_state_path, epoch, loss, config.learning_rate)?;
Ok(BurnCheckpointPaths {
model_path,
trainer_state_path,
})
}
pub fn run_burn_training<B: AutodiffBackend, M: BurnToyModel<B>>(
model: &mut M,
graph: &GraphFeatures,
device: &B::Device,
config: &BurnTrainConfig,
) -> Result<Vec<BurnTrainMetrics>, BurnTrainingError> {
model
.initialize(graph, device)
.map_err(BurnTrainingError::Model)?;
let mut history = Vec::with_capacity(config.epochs);
for epoch in 1..=config.epochs {
let loss = model
.train_step(graph, epoch, config.learning_rate)
.map_err(BurnTrainingError::Model)?;
history.push(BurnTrainMetrics { epoch, loss });
if epoch % config.checkpoint_every == 0 || epoch == config.epochs {
let _ = save_checkpoint_safetensors::<B, M>(model, config, epoch, loss)?;
}
}
Ok(history)
}
}

56
src/burn_adapter.rs Normal file
View File

@@ -0,0 +1,56 @@
#[cfg(feature = "burn")]
use burn::tensor::{backend::Backend, Int, Tensor};
#[cfg(feature = "burn")]
use crate::geometry::PoseState;
#[cfg(feature = "burn")]
use crate::graph::GraphFeatures;
#[cfg(feature = "burn")]
pub struct GraphTensors<B: Backend> {
pub node_features: Tensor<B, 2>,
pub edge_index: Tensor<B, 2, Int>,
pub edge_features: Option<Tensor<B, 2>>,
pub torsion_edge_index: Tensor<B, 2, Int>,
}
#[cfg(feature = "burn")]
impl<B: Backend> GraphTensors<B> {
pub fn from_features(features: &GraphFeatures, device: &B::Device) -> Self {
Self {
node_features: Tensor::<B, 2>::zeros(
[features.num_nodes, features.spec.node_feat_dim],
device,
),
edge_index: Tensor::<B, 2, Int>::zeros([2, features.num_edges], device),
edge_features: Some(Tensor::<B, 2>::zeros(
[features.num_edges, features.spec.edge_feat_dim],
device,
)),
torsion_edge_index: Tensor::<B, 2, Int>::zeros(
[2, features.num_torsion_edges],
device,
),
}
}
}
#[cfg(feature = "burn")]
pub struct PoseTensors<B: Backend> {
pub translation: Tensor<B, 2>,
pub rotation: Tensor<B, 2>,
pub torsions: Tensor<B, 2>,
}
#[cfg(feature = "burn")]
impl<B: Backend> PoseTensors<B> {
pub fn from_pose_batch(poses: &[PoseState], device: &B::Device) -> Self {
let batch = poses.len();
let torsion_dim = poses.first().map(|p| p.torsions.len()).unwrap_or(0);
Self {
translation: Tensor::<B, 2>::zeros([batch, 3], device),
rotation: Tensor::<B, 2>::zeros([batch, 4], device),
torsions: Tensor::<B, 2>::zeros([batch, torsion_dim], device),
}
}
}

42
src/conformer.rs Normal file
View File

@@ -0,0 +1,42 @@
use thiserror::Error;
use crate::graph::{LigandGraph, TorsionEdge};
use crate::geometry::PoseState;
#[derive(Debug, Error)]
pub enum ConformerError {
#[error("coords length {coords_len} != atoms length {atoms_len}")]
CoordinateLengthMismatch { coords_len: usize, atoms_len: usize },
}
#[derive(Debug, Clone)]
pub struct LigandConformer {
pub graph: LigandGraph,
pub coords: Vec<[f32; 3]>,
}
impl LigandConformer {
pub fn new(graph: LigandGraph, coords: Vec<[f32; 3]>) -> Result<Self, ConformerError> {
if graph.atoms.len() != coords.len() {
return Err(ConformerError::CoordinateLengthMismatch {
coords_len: coords.len(),
atoms_len: graph.atoms.len(),
});
}
Ok(Self { graph, coords })
}
}
pub fn apply_torsions(
coords: Vec<[f32; 3]>,
_torsion_edges: &[TorsionEdge],
_torsions: &[f32],
) -> Vec<[f32; 3]> {
// TODO(user-step-2): translation/rotation/torsion 적용 테스트에서 구현.
coords
}
pub fn apply_pose(coords: Vec<[f32; 3]>, _pose: &PoseState) -> Vec<[f32; 3]> {
// TODO(user-step-2): global rotation + translation 구현.
coords
}

27
src/geometry.rs Normal file
View File

@@ -0,0 +1,27 @@
#[derive(Debug, Clone)]
pub struct PoseState {
pub translation: [f32; 3],
pub rotation_quat_xyzw: [f32; 4],
pub torsions: Vec<f32>,
}
#[derive(Debug, Clone)]
pub struct PoseDelta {
pub translation: [f32; 3],
pub rotation_tangent: [f32; 3],
pub torsions: Vec<f32>,
}
pub fn interpolate_pose(pose0: &PoseState, _pose1: &PoseState, _t: f32) -> PoseState {
// TODO(user-step-4): geodesic/interpolation 정책 확정 후 구현.
pose0.clone()
}
pub fn pose_delta(pose0: &PoseState, _pose1: &PoseState) -> PoseDelta {
// TODO(user-step-4): geodesic delta 정의 후 구현.
PoseDelta {
translation: pose0.translation,
rotation_tangent: [0.0, 0.0, 0.0],
torsions: vec![0.0; pose0.torsions.len()],
}
}

83
src/graph.rs Normal file
View File

@@ -0,0 +1,83 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Atom {
pub index: usize,
pub atomic_number: u8,
pub formal_charge: i8,
pub is_aromatic: bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Bond {
pub index: usize,
pub src: usize,
pub dst: usize,
pub bond_order: u8,
pub is_aromatic: bool,
pub is_rotatable: bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TorsionEdge {
pub bond_index: usize,
pub atom_left: usize,
pub atom_right: usize,
pub rotating_side: Vec<usize>,
}
#[derive(Debug, Clone)]
pub struct LigandGraph {
pub atoms: Vec<Atom>,
pub bonds: Vec<Bond>,
pub edge_index: Vec<(usize, usize)>,
pub torsion_edges: Vec<TorsionEdge>,
}
impl LigandGraph {
pub fn new(atoms: Vec<Atom>, bonds: Vec<Bond>, torsion_edges: Vec<TorsionEdge>) -> Self {
let mut edge_index = Vec::with_capacity(bonds.len() * 2);
for b in &bonds {
edge_index.push((b.src, b.dst));
edge_index.push((b.dst, b.src));
}
Self {
atoms,
bonds,
edge_index,
torsion_edges,
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct GraphFeatureSpec {
pub node_feat_dim: usize,
pub edge_feat_dim: usize,
}
impl Default for GraphFeatureSpec {
fn default() -> Self {
Self {
node_feat_dim: 4,
edge_feat_dim: 3,
}
}
}
#[derive(Debug, Clone)]
pub struct GraphFeatures {
pub num_nodes: usize,
pub num_edges: usize,
pub num_torsion_edges: usize,
pub spec: GraphFeatureSpec,
}
impl GraphFeatures {
pub fn from_graph(graph: &LigandGraph) -> Self {
Self {
num_nodes: graph.atoms.len(),
num_edges: graph.edge_index.len(),
num_torsion_edges: graph.torsion_edges.len(),
spec: GraphFeatureSpec::default(),
}
}
}

10
src/lib.rs Normal file
View File

@@ -0,0 +1,10 @@
pub mod api;
pub mod burn_adapter;
pub mod conformer;
pub mod geometry;
pub mod graph;
pub mod simulator;
pub mod viz;
#[cfg(test)]
mod tests;

29
src/simulator.rs Normal file
View File

@@ -0,0 +1,29 @@
use crate::conformer::{ConformerError, LigandConformer};
use crate::geometry::PoseState;
pub fn simulate_at_t(
reference: &LigandConformer,
pose_start: &PoseState,
pose_target: &PoseState,
t: f32,
) -> Result<LigandConformer, ConformerError> {
let _ = (pose_start, pose_target, t);
// TODO(user-step-4): geodesic/interpolation 정책 확정 후 구현.
LigandConformer::new(reference.graph.clone(), reference.coords.clone())
}
pub fn simulate_trajectory(
reference: &LigandConformer,
pose_start: &PoseState,
pose_target: &PoseState,
num_steps: usize,
) -> Result<Vec<LigandConformer>, ConformerError> {
let n = num_steps.max(2);
// TODO(user-step-7): epoch 단위 저장 포맷과 결합할 때 metadata(time/epoch) 추가.
(0..n)
.map(|i| {
let t = i as f32 / (n - 1) as f32;
simulate_at_t(reference, pose_start, pose_target, t)
})
.collect()
}

40
src/tests.rs Normal file
View File

@@ -0,0 +1,40 @@
#[cfg(test)]
mod tests {
use crate::{
api::{OneSampleToyBatch, OneSampleToyPipeline},
conformer::LigandConformer,
geometry::PoseState,
graph::{Atom, Bond, LigandGraph, TorsionEdge},
};
fn simple_reference() -> LigandConformer {
let atoms = vec![
Atom { index: 0, atomic_number: 6, formal_charge: 0, is_aromatic: false },
Atom { index: 1, atomic_number: 6, formal_charge: 0, is_aromatic: false },
Atom { index: 2, atomic_number: 8, formal_charge: 0, is_aromatic: false },
];
let bonds = vec![
Bond { index: 0, src: 0, dst: 1, bond_order: 1, is_aromatic: false, is_rotatable: true },
Bond { index: 1, src: 1, dst: 2, bond_order: 1, is_aromatic: false, is_rotatable: false },
];
let torsion = vec![TorsionEdge { bond_index: 0, atom_left: 0, atom_right: 1, rotating_side: vec![2] }];
let graph = LigandGraph::new(atoms, bonds, torsion);
LigandConformer::new(graph, vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.2, 0.0]]).unwrap()
}
#[test]
fn one_sample_pipeline_contract_smoke_test() {
let reference = simple_reference();
let p0 = PoseState { translation: [0.0, 0.0, 0.0], rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0], torsions: vec![0.0] };
let p1 = PoseState { translation: [0.5, -0.2, 1.0], rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0], torsions: vec![0.1] };
let pipeline = OneSampleToyPipeline::new(OneSampleToyBatch {
graph: reference.graph.clone(),
reference,
pose_start: p0,
pose_target: p1,
});
let contract = pipeline.graph_contract();
assert_eq!(contract.spec.node_feat_dim, 4);
assert_eq!(contract.num_nodes, 3);
}
}

177
src/viz.rs Normal file
View File

@@ -0,0 +1,177 @@
use std::{
fs::{self, File},
io::Write,
path::{Path, PathBuf},
process::Command,
};
use plotters::prelude::*;
use crate::{
conformer::LigandConformer,
geometry::PoseState,
simulator::simulate_trajectory,
};
fn frame_bounds(frames: &[LigandConformer]) -> Option<((f32, f32), (f32, f32))> {
let mut min_x = f32::INFINITY;
let mut max_x = f32::NEG_INFINITY;
let mut min_y = f32::INFINITY;
let mut max_y = f32::NEG_INFINITY;
for frame in frames {
for c in &frame.coords {
min_x = min_x.min(c[0]);
max_x = max_x.max(c[0]);
min_y = min_y.min(c[1]);
max_y = max_y.max(c[1]);
}
}
if min_x.is_finite() && max_x.is_finite() && min_y.is_finite() && max_y.is_finite() {
Some(((min_x, max_x), (min_y, max_y)))
} else {
None
}
}
fn draw_conformer_png<P: AsRef<Path>>(
conformer: &LigandConformer,
out: P,
x_range: (f32, f32),
y_range: (f32, f32),
) -> Result<(), Box<dyn std::error::Error>> {
let root = BitMapBackend::new(out.as_ref(), (960, 960)).into_drawing_area();
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.margin(12)
.build_cartesian_2d(x_range.0..x_range.1, y_range.0..y_range.1)?;
chart
.configure_mesh()
.disable_mesh()
.x_labels(0)
.y_labels(0)
.draw()?;
for b in &conformer.graph.bonds {
let a = conformer.coords[b.src];
let c = conformer.coords[b.dst];
let color = if b.is_rotatable { &RED } else { &BLACK };
chart.draw_series(std::iter::once(PathElement::new(
vec![(a[0], a[1]), (c[0], c[1])],
color.stroke_width(3),
)))?;
}
chart.draw_series(
conformer
.coords
.iter()
.map(|c| Circle::new((c[0], c[1]), 5, BLUE.filled())),
)?;
root.present()?;
Ok(())
}
pub fn plot_static_png<P: AsRef<Path>>(
conformer: &LigandConformer,
out: P,
) -> Result<(), Box<dyn std::error::Error>> {
let single = [conformer.clone()];
let ((min_x, max_x), (min_y, max_y)) =
frame_bounds(&single).ok_or("failed to compute frame bounds")?;
let pad = 1.0f32;
draw_conformer_png(
conformer,
out,
(min_x - pad, max_x + pad),
(min_y - pad, max_y + pad),
)
}
pub fn plot_trajectory_png_frames<P: AsRef<Path>>(
frames: &[LigandConformer],
out_dir: P,
) -> Result<(), Box<dyn std::error::Error>> {
fs::create_dir_all(out_dir.as_ref())?;
let ((min_x, max_x), (min_y, max_y)) = frame_bounds(frames).ok_or("empty frame list")?;
let pad = 1.0f32;
let xr = (min_x - pad, max_x + pad);
let yr = (min_y - pad, max_y + pad);
for (idx, frame) in frames.iter().enumerate() {
let out = out_dir.as_ref().join(format!("frame_{idx:04}.png"));
draw_conformer_png(frame, out, xr, yr)?;
}
Ok(())
}
pub fn export_xyz_frames<P: AsRef<Path>>(
frames: &[LigandConformer],
out_dir: P,
) -> Result<(), std::io::Error> {
fs::create_dir_all(out_dir.as_ref())?;
for (idx, frame) in frames.iter().enumerate() {
let out = out_dir.as_ref().join(format!("frame_{idx:04}.xyz"));
let mut f = File::create(out)?;
writeln!(f, "{}", frame.coords.len())?;
writeln!(f, "riemann-flow-gnn frame {idx}")?;
for (atom, c) in frame.graph.atoms.iter().zip(&frame.coords) {
writeln!(f, "{} {:.6} {:.6} {:.6}", atom.atomic_number, c[0], c[1], c[2])?;
}
}
Ok(())
}
pub fn frames_to_mp4<P: AsRef<Path>>(
frames_dir: P,
output_mp4: P,
fps: u32,
) -> Result<bool, Box<dyn std::error::Error>> {
let input_pattern: PathBuf = frames_dir.as_ref().join("frame_%04d.png");
let status = match Command::new("ffmpeg")
.args([
"-y",
"-framerate",
&fps.to_string(),
"-i",
input_pattern.to_string_lossy().as_ref(),
"-pix_fmt",
"yuv420p",
output_mp4.as_ref().to_string_lossy().as_ref(),
])
.status()
{
Ok(status) => status,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
Err(e) => return Err(Box::new(e)),
};
if !status.success() {
return Err("ffmpeg command failed".into());
}
Ok(true)
}
pub fn export_simulation_video<P: AsRef<Path>>(
reference: &LigandConformer,
pose_start: &PoseState,
pose_target: &PoseState,
num_steps: usize,
out_dir: P,
fps: u32,
) -> Result<Vec<LigandConformer>, Box<dyn std::error::Error>> {
fs::create_dir_all(out_dir.as_ref())?;
let frames_png_dir = out_dir.as_ref().join("frames_png");
let frames_xyz_dir = out_dir.as_ref().join("frames_xyz");
let video_path = out_dir.as_ref().join("simulation.mp4");
let frames = simulate_trajectory(reference, pose_start, pose_target, num_steps)?;
plot_trajectory_png_frames(&frames, &frames_png_dir)?;
export_xyz_frames(&frames, &frames_xyz_dir)?;
let _video_written = frames_to_mp4(&frames_png_dir, &video_path, fps)?;
Ok(frames)
}