Add initial implementation of Riemannian Flow GNN with essential data structures, simulation logic, and visualization utilities
This commit is contained in:
5022
Cargo.lock
generated
Normal file
5022
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
19
Cargo.toml
Normal file
19
Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "riemann-flow-gnn"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
burn = ["dep:burn", "dep:safetensors"]
|
||||
|
||||
[dependencies]
|
||||
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1"
|
||||
burn = { version = "0.14", optional = true, features = ["autodiff", "ndarray"] }
|
||||
safetensors = { version = "0.4", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
32
README.md
32
README.md
@@ -1,3 +1,33 @@
|
||||
# riemann-flow-gnn
|
||||
|
||||
Riemannian Flow matching with GNN implementation in rust/burn
|
||||
Riemannian Flow matching with GNN implementation in rust/burn
|
||||
|
||||
## Toy prototype skeleton (API-first)
|
||||
|
||||
이 저장소는 `R^3 x SO(3) x T^m` one-sample overfitting 실험을 위한 **API 계약 중심 스켈레톤**을 포함합니다.
|
||||
내부 수학/시뮬레이션 로직은 의도적으로 최소화되어 있으며, TODO 지점부터 실험 코드를 채우는 구조입니다.
|
||||
|
||||
### 실행
|
||||
|
||||
```bash
|
||||
cargo run --example toy_pipeline
|
||||
```
|
||||
|
||||
예제는 파이프라인 계약 확인용으로 동작합니다.
|
||||
|
||||
### 현재 포함 범위
|
||||
|
||||
- graph/domain struct (`Atom`, `Bond`, `TorsionEdge`, `LigandGraph`)
|
||||
- one-sample entry API (`api::OneSampleToyPipeline`)
|
||||
- graph/pose/simulator 최소 계약
|
||||
- Burn tensor shape 중심 adapter skeleton (`--features burn`)
|
||||
- TODO 기반 구현 포인트 주석
|
||||
|
||||
### 이후 단계 권장
|
||||
|
||||
1. random init sampling 테스트
|
||||
2. translation/rotation/torsion 단위 테스트 확장
|
||||
3. GNN 모델 정의 및 입력 feature 고정
|
||||
4. geodesic path/velocity supervision 연결
|
||||
5. 학습 루프/실행 스크립트 추가
|
||||
6. epoch별 trajectory 저장 자동화
|
||||
697
burn_manifold_graph_ai_handoff.md
Normal file
697
burn_manifold_graph_ai_handoff.md
Normal file
@@ -0,0 +1,697 @@
|
||||
# Burn 기반 Pose Manifold Graph 구현 AI 전달 문서
|
||||
|
||||
## 0. 문서 목적
|
||||
|
||||
이 문서는 **단일 ligand pose overfitting 실험**을 위해, AI 에이전트가 구현해야 할 범위를 명확히 정의한 handoff 문서이다.
|
||||
|
||||
구현 대상은 다음 세 가지다.
|
||||
|
||||
1. **Burn 규격에 맞는 manifold tensor 표현 및 graph 구현체**
|
||||
2. **graph 입력을 받아 시각화 가능한 구조 표현과 렌더링 파이프라인**
|
||||
3. **time parameter `t ∈ [0, 1]`에 따라 translation / rotation / torsion이 변하는 simulator**
|
||||
|
||||
이번 단계의 목표는 **연구용 완성품**이 아니라, 발표 및 초기 검증을 위한 **작동 가능한 prototype**이다.
|
||||
|
||||
---
|
||||
|
||||
## 1. 문제 설정
|
||||
|
||||
### 1.1 상태 공간
|
||||
|
||||
Ligand pose state는 아래 product space로 둔다.
|
||||
|
||||
- `translation`: `R^3`
|
||||
- `rotation`: `SO(3)`
|
||||
- `torsion`: `T^m` where `m = number of rotatable bonds`
|
||||
|
||||
즉 하나의 pose state는 개념적으로 아래와 같다.
|
||||
|
||||
```text
|
||||
x = (p, R, theta)
|
||||
|
||||
p : global translation, shape [3]
|
||||
R : global rotation
|
||||
theta : torsion angles for rotatable bonds, shape [m]
|
||||
```
|
||||
|
||||
### 1.2 그래프 입력
|
||||
|
||||
입력 ligand는 graph로 표현한다.
|
||||
|
||||
- node = atom
|
||||
- edge = bond
|
||||
- torsion-available edge = 회전 가능한 bond
|
||||
|
||||
이번 단계에서는 **protein / pocket context는 제외**한다.
|
||||
오직 ligand 단일 graph와 target pose만 다룬다.
|
||||
|
||||
### 1.3 실험 목적
|
||||
|
||||
- 주어진 ligand graph 하나에 대해
|
||||
- 임의의 초기 pose에서 출발해
|
||||
- 하나의 fixed target pose로 수렴하는 trajectory를 생성하고
|
||||
- 이를 시각화 가능하게 만든다.
|
||||
|
||||
학습 로직 전체를 이 문서에서 구현하라는 뜻은 아니다.
|
||||
이번 구현의 중심은 **data representation + simulator + visualization**이다.
|
||||
|
||||
---
|
||||
|
||||
## 2. 구현 범위 요약
|
||||
|
||||
AI 에이전트는 아래 항목을 구현한다.
|
||||
|
||||
### 필수 구현
|
||||
|
||||
- Burn-friendly tensor container
|
||||
- ligand graph data structure
|
||||
- torsion metadata extraction/representation
|
||||
- pose state container
|
||||
- explicit simulator for `t=0..1`
|
||||
- 3D/2D visualization export utility
|
||||
|
||||
### 이번 단계에서 제외
|
||||
|
||||
- full training loop
|
||||
- optimizer / scheduler
|
||||
- dataset loader for many molecules
|
||||
- protein-ligand interaction scoring
|
||||
- equivariant GNN model 자체
|
||||
|
||||
---
|
||||
|
||||
## 3. 설계 원칙
|
||||
|
||||
### 3.1 Burn 친화적 설계
|
||||
|
||||
구현체는 Burn 모델과 쉽게 연결될 수 있어야 한다.
|
||||
따라서 다음 원칙을 지킨다.
|
||||
|
||||
- **tensor payload는 Burn tensor로 쉽게 변환 가능**해야 한다.
|
||||
- 구조체는 geometry / graph / rendering 책임을 분리한다.
|
||||
- manifold-specific logic는 순수 수학 유틸로 분리한다.
|
||||
- 시각화는 모델과 독립적인 모듈로 둔다.
|
||||
|
||||
### 3.2 API 변동에 덜 민감한 설계
|
||||
|
||||
Burn 버전 변화에 덜 흔들리도록, 내부 표현과 backend 종속성을 너무 깊게 섞지 않는다.
|
||||
|
||||
권장 방식:
|
||||
|
||||
- domain object는 plain Rust struct로 유지
|
||||
- 필요 시 `to_tensor::<B: Backend>(&device)` 같은 adapter 제공
|
||||
- 학습용 Autodiff backend와 추론용 backend를 쉽게 교체할 수 있게 설계
|
||||
|
||||
### 3.3 발표용 prototype 우선
|
||||
|
||||
이번 구현은 성능 최적화보다 아래가 중요하다.
|
||||
|
||||
- 구조가 명확할 것
|
||||
- simulator가 deterministic하게 잘 동작할 것
|
||||
- geometry가 눈으로 드러날 것
|
||||
|
||||
---
|
||||
|
||||
## 4. 구현해야 할 핵심 데이터 구조
|
||||
|
||||
### 4.1 Atom / Bond / LigandGraph
|
||||
|
||||
#### Atom
|
||||
|
||||
최소 필드:
|
||||
|
||||
```rust
|
||||
pub struct Atom {
|
||||
pub index: usize,
|
||||
pub atomic_number: u8,
|
||||
pub formal_charge: i8,
|
||||
pub is_aromatic: bool,
|
||||
}
|
||||
```
|
||||
|
||||
#### Bond
|
||||
|
||||
최소 필드:
|
||||
|
||||
```rust
|
||||
pub struct Bond {
|
||||
pub index: usize,
|
||||
pub src: usize,
|
||||
pub dst: usize,
|
||||
pub bond_order: u8,
|
||||
pub is_aromatic: bool,
|
||||
pub is_rotatable: bool,
|
||||
}
|
||||
```
|
||||
|
||||
#### TorsionEdge
|
||||
|
||||
회전 가능한 bond에 대해, 어떤 원자 집합이 어느 쪽 fragment로 회전하는지 알아야 한다.
|
||||
이 metadata가 simulator에서 중요하다.
|
||||
|
||||
```rust
|
||||
pub struct TorsionEdge {
|
||||
pub bond_index: usize,
|
||||
pub atom_left: usize,
|
||||
pub atom_right: usize,
|
||||
pub rotating_side: Vec<usize>,
|
||||
}
|
||||
```
|
||||
|
||||
설명:
|
||||
|
||||
- `atom_left`, `atom_right`는 torsion axis를 이루는 bond의 두 원자
|
||||
- `rotating_side`는 회전 적용 대상 atom index 목록
|
||||
|
||||
#### LigandGraph
|
||||
|
||||
```rust
|
||||
pub struct LigandGraph {
|
||||
pub atoms: Vec<Atom>,
|
||||
pub bonds: Vec<Bond>,
|
||||
pub edge_index: Vec<(usize, usize)>,
|
||||
pub torsion_edges: Vec<TorsionEdge>,
|
||||
}
|
||||
```
|
||||
|
||||
요구사항:
|
||||
|
||||
- undirected bond graph를 기본으로 하되
|
||||
- GNN 입력용으로는 directed edge pair로 변환 가능해야 함
|
||||
- torsion edge list를 별도로 유지해야 함
|
||||
|
||||
---
|
||||
|
||||
### 4.2 PoseState
|
||||
|
||||
pose state는 ligand의 geometry 상태를 담는다.
|
||||
|
||||
#### 권장 표현
|
||||
|
||||
```rust
|
||||
pub struct PoseState {
|
||||
pub translation: [f32; 3],
|
||||
pub rotation_quat_xyzw: [f32; 4],
|
||||
pub torsions: Vec<f32>,
|
||||
}
|
||||
```
|
||||
|
||||
주의:
|
||||
|
||||
- quaternion은 normalize 보장 필요
|
||||
- torsion은 radian 단위
|
||||
- torsion 범위는 기본적으로 `[-pi, pi)` 로 유지
|
||||
|
||||
rotation은 내부적으로 quaternion으로 저장하되,
|
||||
필요 시 axis-angle / rotation matrix 변환 유틸을 제공한다.
|
||||
|
||||
---
|
||||
|
||||
### 4.3 LigandConformer
|
||||
|
||||
실제 원자 좌표를 가진 구조 표현.
|
||||
|
||||
```rust
|
||||
pub struct LigandConformer {
|
||||
pub graph: LigandGraph,
|
||||
pub coords: Vec<[f32; 3]>,
|
||||
}
|
||||
```
|
||||
|
||||
요구사항:
|
||||
|
||||
- `coords.len() == graph.atoms.len()` 이어야 함
|
||||
- reference conformer를 기준으로 pose transform 적용 가능해야 함
|
||||
|
||||
---
|
||||
|
||||
## 5. Burn에 연결 가능한 tensor adapter
|
||||
|
||||
AI 에이전트는 아래와 같은 방향의 adapter를 제공한다.
|
||||
|
||||
```rust
|
||||
pub struct GraphTensors<B: burn::tensor::backend::Backend> {
|
||||
pub node_features: burn::tensor::Tensor<B, 2>,
|
||||
pub edge_index: burn::tensor::Tensor<B, 2, burn::tensor::Int>,
|
||||
pub edge_features: Option<burn::tensor::Tensor<B, 2>>,
|
||||
pub torsion_edge_index: burn::tensor::Tensor<B, 2, burn::tensor::Int>,
|
||||
}
|
||||
```
|
||||
|
||||
목표는 정확한 최종 타입을 박제하는 것이 아니라, 아래 contract를 만족하는 것이다.
|
||||
|
||||
### contract
|
||||
|
||||
- node features: `[num_nodes, node_feat_dim]`
|
||||
- edge index: `[2, num_edges]`
|
||||
- torsion edge index: `[2, num_rotatable_bonds]` 또는 equivalent representation
|
||||
- pose tensor export: translation / quaternion / torsion을 batchable tensor로 변환 가능
|
||||
|
||||
### node feature 최소 구성
|
||||
|
||||
최초 버전은 아래 정도면 충분하다.
|
||||
|
||||
- atomic number
|
||||
- formal charge
|
||||
- aromatic flag
|
||||
- degree(optional)
|
||||
|
||||
### edge feature 최소 구성
|
||||
|
||||
- bond order
|
||||
- aromatic flag
|
||||
- rotatable flag
|
||||
|
||||
### pose tensor export
|
||||
|
||||
```rust
|
||||
pub struct PoseTensors<B: burn::tensor::backend::Backend> {
|
||||
pub translation: burn::tensor::Tensor<B, 2>, // [batch, 3]
|
||||
pub rotation: burn::tensor::Tensor<B, 2>, // [batch, 4] quaternion
|
||||
pub torsions: burn::tensor::Tensor<B, 2>, // [batch, m]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 구현해야 할 manifold math 유틸
|
||||
|
||||
이 모듈은 모델과 완전히 분리된 순수 수학 유틸이어야 한다.
|
||||
|
||||
파일 예시:
|
||||
|
||||
```text
|
||||
src/geometry/
|
||||
euclidean.rs
|
||||
so3.rs
|
||||
torus.rs
|
||||
pose.rs
|
||||
```
|
||||
|
||||
### 6.1 Translation (`R^3`)
|
||||
|
||||
필수 함수:
|
||||
|
||||
- `lerp_translation(p0, p1, t)`
|
||||
- `translation_delta(p0, p1)`
|
||||
|
||||
### 6.2 Rotation (`SO(3)`)
|
||||
|
||||
필수 함수:
|
||||
|
||||
- quaternion normalize
|
||||
- quaternion multiply / inverse
|
||||
- quaternion to rotation matrix
|
||||
- rotation matrix to quaternion (optional)
|
||||
- `slerp(q0, q1, t)`
|
||||
- `relative_rotation(q_from, q_to)`
|
||||
- `so3_log(q_rel) -> [f32; 3]`
|
||||
- `so3_exp(omega) -> quat`
|
||||
|
||||
설명:
|
||||
|
||||
- simulator에서는 `slerp` 기반 interpolation만 먼저 구현해도 충분
|
||||
- 이후 tangent vector supervision을 위해 `log/exp` utility 유지 권장
|
||||
|
||||
### 6.3 Torsion (`T^m`)
|
||||
|
||||
필수 함수:
|
||||
|
||||
- `wrap_pi(x)`
|
||||
- `wrap_vec(theta)`
|
||||
- `shortest_angle_delta(a, b)`
|
||||
- `interpolate_torsion(theta0, theta1, t)`
|
||||
|
||||
정의:
|
||||
|
||||
```text
|
||||
shortest_angle_delta(a, b) = wrap_pi(b - a)
|
||||
interpolate_torsion(theta0, theta1, t) = wrap(theta0 + t * shortest_delta)
|
||||
```
|
||||
|
||||
### 6.4 Product Pose Geometry
|
||||
|
||||
필수 함수:
|
||||
|
||||
- `interpolate_pose(pose0, pose1, t) -> PoseState`
|
||||
- `pose_delta(pose0, pose1) -> PoseDelta`
|
||||
|
||||
예상 구조:
|
||||
|
||||
```rust
|
||||
pub struct PoseDelta {
|
||||
pub translation: [f32; 3],
|
||||
pub rotation_tangent: [f32; 3],
|
||||
pub torsions: Vec<f32>,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 좌표 복원 / simulator 요구사항
|
||||
|
||||
핵심 요구는 다음과 같다.
|
||||
|
||||
> `reference conformer + pose state` 를 받아, 시각화 가능한 현재 좌표를 만든다.
|
||||
|
||||
### 7.1 적용 순서
|
||||
|
||||
권장 적용 순서:
|
||||
|
||||
1. reference conformer 좌표 복사
|
||||
2. rotatable bond별 torsion 회전 적용
|
||||
3. global rotation 적용
|
||||
4. global translation 적용
|
||||
|
||||
### 7.2 torsion 회전 적용
|
||||
|
||||
각 `TorsionEdge`에 대해:
|
||||
|
||||
- axis = bond direction
|
||||
- pivot = bond axis 상의 원자 위치
|
||||
- rotating_side에 포함된 atom들만 회전
|
||||
- angle = 해당 torsion value
|
||||
|
||||
즉, torsion은 local internal transform이다.
|
||||
|
||||
### 7.3 global transform 적용
|
||||
|
||||
- rotation quaternion으로 전체 원자 좌표 회전
|
||||
- translation vector 더하기
|
||||
|
||||
### 7.4 simulator
|
||||
|
||||
입력:
|
||||
|
||||
- `reference conformer`
|
||||
- `pose_start`
|
||||
- `pose_target`
|
||||
- `t in [0,1]`
|
||||
|
||||
출력:
|
||||
|
||||
- `PoseState at t`
|
||||
- `LigandConformer at t`
|
||||
|
||||
즉,
|
||||
|
||||
```rust
|
||||
pub fn simulate_at_t(
|
||||
reference: &LigandConformer,
|
||||
pose_start: &PoseState,
|
||||
pose_target: &PoseState,
|
||||
t: f32,
|
||||
) -> LigandConformer
|
||||
```
|
||||
|
||||
형태의 API가 있어야 한다.
|
||||
|
||||
### 7.5 trajectory
|
||||
|
||||
추가로 아래 helper도 구현한다.
|
||||
|
||||
```rust
|
||||
pub fn simulate_trajectory(
|
||||
reference: &LigandConformer,
|
||||
pose_start: &PoseState,
|
||||
pose_target: &PoseState,
|
||||
num_steps: usize,
|
||||
) -> Vec<LigandConformer>
|
||||
```
|
||||
|
||||
- `t = 0..1` 균등 분할
|
||||
- endpoint 포함
|
||||
- 결과는 animation 또는 프레임 저장에 활용 가능
|
||||
|
||||
---
|
||||
|
||||
## 8. Visualization 요구사항
|
||||
|
||||
이번 단계에서 visualization은 매우 중요하다.
|
||||
발표용 데모에 직접 쓰일 수 있어야 한다.
|
||||
|
||||
### 필수 출력
|
||||
|
||||
- ligand graph connectivity 확인용 2D/3D static view
|
||||
- current pose 단일 프레임 렌더링
|
||||
- trajectory animation용 frame sequence export
|
||||
|
||||
### 최소 기능
|
||||
|
||||
#### A. static plot
|
||||
|
||||
입력:
|
||||
- atom coordinates
|
||||
- bond list
|
||||
|
||||
출력:
|
||||
- atoms = points
|
||||
- bonds = lines
|
||||
- rotatable bond는 다른 색상 또는 강조
|
||||
|
||||
#### B. trajectory frames
|
||||
|
||||
입력:
|
||||
- `Vec<LigandConformer>`
|
||||
|
||||
출력:
|
||||
- png sequence 또는 gif용 intermediate frames
|
||||
|
||||
#### C. optional: HTML viewer
|
||||
|
||||
가능하면 아래 중 하나 추가 검토:
|
||||
|
||||
- plotly 기반 HTML
|
||||
- three.js/json export용 포맷
|
||||
- simple `.xyz` / `.sdf` frame export
|
||||
|
||||
### 시각화 우선순위
|
||||
|
||||
1. 정적 PNG 출력
|
||||
2. trajectory frame sequence 출력
|
||||
3. 선택적으로 gif/mp4/HTML
|
||||
|
||||
주의:
|
||||
|
||||
- visualization 구현은 model과 완전히 분리
|
||||
- geometry correctness 검증이 목적이지 예쁜 UI가 목적은 아님
|
||||
|
||||
---
|
||||
|
||||
## 9. 권장 디렉터리 구조
|
||||
|
||||
```text
|
||||
src/
|
||||
graph/
|
||||
atom.rs
|
||||
bond.rs
|
||||
torsion.rs
|
||||
ligand_graph.rs
|
||||
featurize.rs
|
||||
|
||||
geometry/
|
||||
euclidean.rs
|
||||
so3.rs
|
||||
torus.rs
|
||||
pose.rs
|
||||
|
||||
conformer/
|
||||
conformer.rs
|
||||
apply_torsion.rs
|
||||
apply_pose.rs
|
||||
|
||||
burn_adapter/
|
||||
graph_tensors.rs
|
||||
pose_tensors.rs
|
||||
|
||||
simulator/
|
||||
interpolate.rs
|
||||
trajectory.rs
|
||||
|
||||
viz/
|
||||
plot_static.rs
|
||||
plot_trajectory.rs
|
||||
export.rs
|
||||
|
||||
lib.rs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. 구현 상세 계약
|
||||
|
||||
### 10.1 반드시 보장해야 할 것
|
||||
|
||||
- quaternion normalization이 항상 유지될 것
|
||||
- torsion angle은 항상 wrapped range에 있을 것
|
||||
- `rotating_side` 계산이 일관될 것
|
||||
- bond axis 기준 회전이 올바를 것
|
||||
- `simulate_at_t(..., 0.0)` 는 시작 pose와 일치할 것
|
||||
- `simulate_at_t(..., 1.0)` 는 타깃 pose와 일치할 것
|
||||
|
||||
### 10.2 numerical sanity checks
|
||||
|
||||
AI 에이전트는 아래 테스트를 함께 작성한다.
|
||||
|
||||
#### Test 1: wrap correctness
|
||||
|
||||
- `pi - eps` 와 `-pi + eps` 의 차이가 작은 값으로 처리되는지 검증
|
||||
|
||||
#### Test 2: quaternion norm
|
||||
|
||||
- interpolation 전후 quaternion norm ≈ 1
|
||||
|
||||
#### Test 3: endpoint correctness
|
||||
|
||||
- trajectory 첫 프레임과 마지막 프레임이 입력 pose와 일치
|
||||
|
||||
#### Test 4: rigid transform invariance
|
||||
|
||||
- torsion 없는 molecule에서는 internal geometry가 rigid하게 유지
|
||||
|
||||
#### Test 5: torsion rotation locality
|
||||
|
||||
- 특정 torsion 회전 시 rotating_side atom만 변하는지 확인
|
||||
|
||||
---
|
||||
|
||||
## 11. 입력 데이터 가정
|
||||
|
||||
이번 구현은 범용 loader까지 강제하지 않는다.
|
||||
아래 두 입력 경로 중 하나만 우선 지원하면 된다.
|
||||
|
||||
### 경로 A: 수동 구성
|
||||
|
||||
- atom list
|
||||
- bond list
|
||||
- reference coordinates
|
||||
- rotatable bond metadata
|
||||
- target pose
|
||||
|
||||
### 경로 B: RDKit 등 외부 전처리 결과 import
|
||||
|
||||
- JSON 또는 serde-friendly 포맷으로 graph + conformer metadata load
|
||||
|
||||
권장:
|
||||
|
||||
- Rust 내부 구현은 외부 화학 toolkit에 강하게 묶지 말 것
|
||||
- 외부 전처리 결과를 import하는 구조가 더 안전함
|
||||
|
||||
---
|
||||
|
||||
## 12. 구현 우선순위
|
||||
|
||||
### Phase 1
|
||||
|
||||
- `Atom`, `Bond`, `TorsionEdge`, `LigandGraph`
|
||||
- `PoseState`
|
||||
- `wrap_pi`, `shortest_angle_delta`
|
||||
- quaternion normalize + slerp
|
||||
- `interpolate_pose`
|
||||
|
||||
### Phase 2
|
||||
|
||||
- `LigandConformer`
|
||||
- torsion application
|
||||
- global rotation/translation application
|
||||
- `simulate_at_t`
|
||||
|
||||
### Phase 3
|
||||
|
||||
- Burn tensor adapter
|
||||
- graph/node/edge feature export
|
||||
- pose tensor export
|
||||
|
||||
### Phase 4
|
||||
|
||||
- visualization static export
|
||||
- trajectory frame export
|
||||
- tests / sanity checks
|
||||
|
||||
---
|
||||
|
||||
## 13. 기대 산출물
|
||||
|
||||
AI 에이전트는 최종적으로 아래를 제공해야 한다.
|
||||
|
||||
### 코드 산출물
|
||||
|
||||
- Rust crate source
|
||||
- 기본 unit tests
|
||||
- example 실행 코드
|
||||
|
||||
### example 요구
|
||||
|
||||
example은 아래를 보여줘야 한다.
|
||||
|
||||
1. ligand graph 생성 또는 로드
|
||||
2. reference conformer 준비
|
||||
3. start pose / target pose 설정
|
||||
4. `t=0.0, 0.25, 0.5, 0.75, 1.0` 프레임 생성
|
||||
5. static image 또는 trajectory frame export
|
||||
6. graph tensor / pose tensor 변환 예시
|
||||
|
||||
---
|
||||
|
||||
## 14. AI 에이전트에게 주는 구현 지침
|
||||
|
||||
### 반드시 지킬 것
|
||||
|
||||
- 너무 많은 abstraction을 한 번에 넣지 말 것
|
||||
- 우선 **작동하는 명시적 구현**을 만들 것
|
||||
- geometry correctness를 최우선으로 둘 것
|
||||
- Burn 모델 학습 코드까지 과도하게 확장하지 말 것
|
||||
|
||||
### 선호하는 구현 스타일
|
||||
|
||||
- plain Rust struct 중심
|
||||
- 명시적 타입
|
||||
- unit test 풍부하게
|
||||
- 실패 시 panic보다 `Result` 기반 오류 처리 선호
|
||||
- visualization/output path는 예제에서 바로 실행 가능하게 구성
|
||||
|
||||
### 금지 사항
|
||||
|
||||
- 불필요한 거대 프레임워크 도입
|
||||
- protein context까지 무리하게 확장
|
||||
- 학습 코드와 시각화 코드를 강하게 결합
|
||||
- 화학 toolkit 종속성을 core domain object 안에 깊게 삽입
|
||||
|
||||
---
|
||||
|
||||
## 15. 구현 완료 판정 기준
|
||||
|
||||
아래 조건을 만족하면 이번 단계 구현 완료로 본다.
|
||||
|
||||
- ligand graph와 torsion metadata를 표현할 수 있음
|
||||
- pose state를 translation / rotation / torsion으로 표현할 수 있음
|
||||
- `t=0..1` explicit interpolation simulator가 동작함
|
||||
- conformer 좌표를 올바르게 복원함
|
||||
- static frame 및 trajectory frame을 export할 수 있음
|
||||
- Burn 입력용 tensor adapter가 존재함
|
||||
- 최소한의 numerical sanity tests가 모두 통과함
|
||||
|
||||
---
|
||||
|
||||
## 16. 선택적 확장 아이디어
|
||||
|
||||
이번 단계 이후 확장 후보:
|
||||
|
||||
- batch support
|
||||
- pocket context 추가
|
||||
- learned vector field model 연결
|
||||
- explicit path supervision dataset 생성기
|
||||
- SE(3)-equivariant / graph neural model 연결
|
||||
- target pose 단일 개체가 아니라 multiple conformer target 지원
|
||||
|
||||
---
|
||||
|
||||
## 17. 마지막 메모
|
||||
|
||||
이 구현은 최종 모델이 아니라, 아래 질문에 답하기 위한 prototype이다.
|
||||
|
||||
> “ligand pose manifold (`R^3 × SO(3) × T^m`) 위의 상태를 Rust/Burn 친화적으로 표현하고,
|
||||
> graph 입력과 연결하며,
|
||||
> explicit path simulator와 visualization까지 한 번에 돌릴 수 있는가?”
|
||||
|
||||
이 질문에 yes를 만들 수 있으면 이번 단계는 충분히 성공이다.
|
||||
343
data/toy/aspirin_reference.json
Normal file
343
data/toy/aspirin_reference.json
Normal file
@@ -0,0 +1,343 @@
|
||||
{
|
||||
"smiles": "CC(=O)OC1=CC=CC=C1C(O)=O",
|
||||
"generator": {
|
||||
"tool": "rdkit",
|
||||
"rdkit_version": "2025.09.4",
|
||||
"method": "ETKDGv3 + UFFOptimizeMolecule",
|
||||
"seed": 20260415
|
||||
},
|
||||
"atoms": [
|
||||
{
|
||||
"index": 0,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"atomic_number": 8,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"atomic_number": 8,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": true
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": true
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": true
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": true
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": true
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": true
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"atomic_number": 6,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
},
|
||||
{
|
||||
"index": 11,
|
||||
"atomic_number": 8,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
},
|
||||
{
|
||||
"index": 12,
|
||||
"atomic_number": 8,
|
||||
"formal_charge": 0,
|
||||
"is_aromatic": false
|
||||
}
|
||||
],
|
||||
"bonds": [
|
||||
{
|
||||
"index": 0,
|
||||
"src": 0,
|
||||
"dst": 1,
|
||||
"bond_order": 1,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"src": 1,
|
||||
"dst": 2,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"src": 1,
|
||||
"dst": 3,
|
||||
"bond_order": 1,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": true
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"src": 3,
|
||||
"dst": 4,
|
||||
"bond_order": 1,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": true
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"src": 4,
|
||||
"dst": 5,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": true,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"src": 5,
|
||||
"dst": 6,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": true,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"src": 6,
|
||||
"dst": 7,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": true,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"src": 7,
|
||||
"dst": 8,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": true,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"src": 8,
|
||||
"dst": 9,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": true,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"src": 9,
|
||||
"dst": 10,
|
||||
"bond_order": 1,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": true
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"src": 10,
|
||||
"dst": 11,
|
||||
"bond_order": 1,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 11,
|
||||
"src": 10,
|
||||
"dst": 12,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": false,
|
||||
"is_rotatable": false
|
||||
},
|
||||
{
|
||||
"index": 12,
|
||||
"src": 9,
|
||||
"dst": 4,
|
||||
"bond_order": 2,
|
||||
"is_aromatic": true,
|
||||
"is_rotatable": false
|
||||
}
|
||||
],
|
||||
"coords_angstrom": [
|
||||
[
|
||||
-2.445027592718335,
|
||||
-1.67948133887539,
|
||||
-1.2311204827416857
|
||||
],
|
||||
[
|
||||
-2.3256540791911036,
|
||||
-0.33513666420602944,
|
||||
-0.5982962602008224
|
||||
],
|
||||
[
|
||||
-3.2501478181982857,
|
||||
0.5081276650303239,
|
||||
-0.7531554024479279
|
||||
],
|
||||
[
|
||||
-1.240796578897732,
|
||||
-0.07117406663982812,
|
||||
0.24766934621062192
|
||||
],
|
||||
[
|
||||
-0.8693310896814674,
|
||||
1.1975591923239997,
|
||||
0.7212683477007807
|
||||
],
|
||||
[
|
||||
-1.0718661569800187,
|
||||
2.348408762854174,
|
||||
-0.0659265674971829
|
||||
],
|
||||
[
|
||||
-0.6879381383599061,
|
||||
3.605496329692259,
|
||||
0.40104039015513443
|
||||
],
|
||||
[
|
||||
-0.08809733429840264,
|
||||
3.732148598191945,
|
||||
1.6516331751807383
|
||||
],
|
||||
[
|
||||
0.1383074740644653,
|
||||
2.6011946492493183,
|
||||
2.4373021116517832
|
||||
],
|
||||
[
|
||||
-0.2413518197061185,
|
||||
1.3212637178820958,
|
||||
1.9855525634937683
|
||||
],
|
||||
[
|
||||
0.016171886685239042,
|
||||
0.13903726167392894,
|
||||
2.8502637354849125
|
||||
],
|
||||
[
|
||||
-0.38040222522103195,
|
||||
-1.1365956849618335,
|
||||
2.4579933961830043
|
||||
],
|
||||
[
|
||||
0.6001216008188676,
|
||||
0.2801422824805318,
|
||||
3.9591263627209816
|
||||
]
|
||||
],
|
||||
"bond_lengths_angstrom": [
|
||||
{
|
||||
"bond_index": 0,
|
||||
"src": 0,
|
||||
"dst": 1,
|
||||
"length_angstrom": 1.4906304493999016
|
||||
},
|
||||
{
|
||||
"bond_index": 1,
|
||||
"src": 1,
|
||||
"dst": 2,
|
||||
"length_angstrom": 1.26085873767184
|
||||
},
|
||||
{
|
||||
"bond_index": 2,
|
||||
"src": 1,
|
||||
"dst": 3,
|
||||
"length_angstrom": 1.4008032895762923
|
||||
},
|
||||
{
|
||||
"bond_index": 3,
|
||||
"src": 3,
|
||||
"dst": 4,
|
||||
"length_angstrom": 1.4042673200969
|
||||
},
|
||||
{
|
||||
"bond_index": 4,
|
||||
"src": 4,
|
||||
"dst": 5,
|
||||
"length_angstrom": 1.4089538750394368
|
||||
},
|
||||
{
|
||||
"bond_index": 5,
|
||||
"src": 5,
|
||||
"dst": 6,
|
||||
"length_angstrom": 1.3948935492429244
|
||||
},
|
||||
{
|
||||
"bond_index": 6,
|
||||
"src": 6,
|
||||
"dst": 7,
|
||||
"length_angstrom": 1.3927785542904383
|
||||
},
|
||||
{
|
||||
"bond_index": 7,
|
||||
"src": 7,
|
||||
"dst": 8,
|
||||
"length_angstrom": 1.3955614101909997
|
||||
},
|
||||
{
|
||||
"bond_index": 8,
|
||||
"src": 8,
|
||||
"dst": 9,
|
||||
"length_angstrom": 1.4094119421510216
|
||||
},
|
||||
{
|
||||
"bond_index": 9,
|
||||
"src": 9,
|
||||
"dst": 10,
|
||||
"length_angstrom": 1.4871796340988404
|
||||
},
|
||||
{
|
||||
"bond_index": 10,
|
||||
"src": 10,
|
||||
"dst": 11,
|
||||
"length_angstrom": 1.3922594800801036
|
||||
},
|
||||
{
|
||||
"bond_index": 11,
|
||||
"src": 10,
|
||||
"dst": 12,
|
||||
"length_angstrom": 1.2611440130351712
|
||||
},
|
||||
{
|
||||
"bond_index": 12,
|
||||
"src": 9,
|
||||
"dst": 4,
|
||||
"length_angstrom": 1.417065754107979
|
||||
}
|
||||
],
|
||||
"torsion_edges": []
|
||||
}
|
||||
230
examples/toy_pipeline.rs
Normal file
230
examples/toy_pipeline.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
#[cfg(feature = "burn")]
|
||||
use std::{collections::BTreeMap, marker::PhantomData, path::Path};
|
||||
use std::fs;
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
use burn::{backend::Autodiff, tensor::backend::AutodiffBackend};
|
||||
use riemann_flow_gnn::{
|
||||
api::{OneSampleToyBatch, OneSampleToyPipeline},
|
||||
conformer::LigandConformer,
|
||||
geometry::PoseState,
|
||||
graph::{Atom, Bond, LigandGraph},
|
||||
viz::{export_simulation_video, plot_static_png},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
#[cfg(feature = "burn")]
|
||||
use riemann_flow_gnn::api::burn_training::{
|
||||
run_burn_training, BurnToyModel, BurnTrainConfig,
|
||||
};
|
||||
#[cfg(feature = "burn")]
|
||||
use safetensors::{
|
||||
serialize_to_file,
|
||||
tensor::{Dtype, TensorView},
|
||||
};
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
type ExampleBackend = Autodiff<burn::backend::NdArray>;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToyAtom {
|
||||
index: usize,
|
||||
atomic_number: u8,
|
||||
formal_charge: i8,
|
||||
is_aromatic: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToyBond {
|
||||
index: usize,
|
||||
src: usize,
|
||||
dst: usize,
|
||||
bond_order: u8,
|
||||
is_aromatic: bool,
|
||||
is_rotatable: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToyReferenceData {
|
||||
atoms: Vec<ToyAtom>,
|
||||
bonds: Vec<ToyBond>,
|
||||
coords_angstrom: Vec<[f32; 3]>,
|
||||
}
|
||||
|
||||
fn load_toy_reference(path: &str) -> Result<(LigandGraph, LigandConformer), Box<dyn std::error::Error>> {
|
||||
let raw = fs::read_to_string(path)?;
|
||||
let parsed: ToyReferenceData = serde_json::from_str(&raw)?;
|
||||
|
||||
let atoms = parsed
|
||||
.atoms
|
||||
.into_iter()
|
||||
.map(|a| Atom {
|
||||
index: a.index,
|
||||
atomic_number: a.atomic_number,
|
||||
formal_charge: a.formal_charge,
|
||||
is_aromatic: a.is_aromatic,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let bonds = parsed
|
||||
.bonds
|
||||
.into_iter()
|
||||
.map(|b| Bond {
|
||||
index: b.index,
|
||||
src: b.src,
|
||||
dst: b.dst,
|
||||
bond_order: b.bond_order,
|
||||
is_aromatic: b.is_aromatic,
|
||||
is_rotatable: b.is_rotatable,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let graph = LigandGraph::new(atoms, bonds, vec![]);
|
||||
let reference = LigandConformer::new(graph.clone(), parsed.coords_angstrom)?;
|
||||
Ok((graph, reference))
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
struct ExampleBurnModel<B: AutodiffBackend> {
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
impl<B: AutodiffBackend> ExampleBurnModel<B> {
|
||||
fn forward_template(&self, _graph: &riemann_flow_gnn::graph::GraphFeatures) -> Result<(), String> {
|
||||
// TODO(user): Burn tensor 입력(GraphTensors, PoseTensors) 기반 forward 구현
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn loss_template(&self) -> Result<f64, String> {
|
||||
// TODO(user): translation/rotation/torsion loss 조합 구현
|
||||
Ok(1.0)
|
||||
}
|
||||
|
||||
fn optimizer_step_template(&mut self, _lr: f64) -> Result<(), String> {
|
||||
// TODO(user): Burn optimizer step 구현
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
impl<B: AutodiffBackend> BurnToyModel<B> for ExampleBurnModel<B> {
|
||||
fn name(&self) -> &'static str {
|
||||
"burn-toy-model"
|
||||
}
|
||||
|
||||
fn new(_device: &B::Device) -> Self {
|
||||
Self {
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize(
|
||||
&mut self,
|
||||
_graph: &riemann_flow_gnn::graph::GraphFeatures,
|
||||
_device: &B::Device,
|
||||
) -> Result<(), String> {
|
||||
// TODO(user): Burn 모델 파라미터 초기화
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn train_step(
|
||||
&mut self,
|
||||
graph: &riemann_flow_gnn::graph::GraphFeatures,
|
||||
epoch: usize,
|
||||
lr: f64,
|
||||
) -> Result<f64, String> {
|
||||
self.forward_template(graph)?;
|
||||
let base_loss = self.loss_template()?;
|
||||
self.optimizer_step_template(lr)?;
|
||||
Ok(base_loss / (epoch as f64))
|
||||
}
|
||||
|
||||
fn save_model_safetensors(&self, path: &Path) -> Result<(), String> {
|
||||
// TODO(user): 실제 model state tensor들을 safetensors로 저장
|
||||
let value_bytes = 1.0f32.to_le_bytes();
|
||||
let value_view =
|
||||
TensorView::new(Dtype::F32, vec![1], &value_bytes).map_err(|e| e.to_string())?;
|
||||
let tensors = BTreeMap::from([("placeholder_weight".to_string(), value_view)]);
|
||||
serialize_to_file(tensors, &None, path).map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_model_safetensors(_device: &B::Device, _path: &Path) -> Result<Self, String> {
|
||||
// TODO(user): safetensors에서 model state 복원
|
||||
Ok(Self {
|
||||
_backend: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (graph, reference) = load_toy_reference("data/toy/aspirin_reference.json")?;
|
||||
|
||||
let pose_start = PoseState {
|
||||
translation: [0.0, 0.0, 0.0],
|
||||
rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0],
|
||||
torsions: vec![0.0],
|
||||
};
|
||||
let pose_target = PoseState {
|
||||
translation: [1.0, 1.5, 0.5],
|
||||
rotation_quat_xyzw: [0.0, 0.0, 0.38268343, 0.9238795],
|
||||
torsions: vec![1.2],
|
||||
};
|
||||
|
||||
let pipeline = OneSampleToyPipeline::new(OneSampleToyBatch {
|
||||
graph,
|
||||
reference: reference.clone(),
|
||||
pose_start,
|
||||
pose_target,
|
||||
});
|
||||
let graph_features = pipeline.graph_contract();
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
let history = {
|
||||
let device = Default::default();
|
||||
let mut model = ExampleBurnModel::<ExampleBackend>::new(&device);
|
||||
let train_config = BurnTrainConfig {
|
||||
epochs: 20,
|
||||
learning_rate: 1e-3,
|
||||
checkpoint_every: 5,
|
||||
checkpoint_dir: "outputs/checkpoints".to_string(),
|
||||
};
|
||||
run_burn_training::<ExampleBackend, _>(
|
||||
&mut model,
|
||||
&graph_features,
|
||||
&device,
|
||||
&train_config,
|
||||
)?
|
||||
};
|
||||
|
||||
// 3) simulation video export
|
||||
let trajectory = export_simulation_video(
|
||||
&pipeline.batch.reference,
|
||||
&pipeline.batch.pose_start,
|
||||
&pipeline.batch.pose_target,
|
||||
48,
|
||||
"outputs",
|
||||
24,
|
||||
)?;
|
||||
plot_static_png(&trajectory[0], "outputs/static_start.png")?;
|
||||
plot_static_png(trajectory.last().ok_or("empty trajectory")?, "outputs/static_end.png")?;
|
||||
println!(
|
||||
"contract: nodes={}, edges={}, torsions={}, node_dim={}, edge_dim={}",
|
||||
graph_features.num_nodes,
|
||||
graph_features.num_edges,
|
||||
graph_features.num_torsion_edges,
|
||||
graph_features.spec.node_feat_dim,
|
||||
graph_features.spec.edge_feat_dim
|
||||
);
|
||||
#[cfg(feature = "burn")]
|
||||
println!(
|
||||
"burn training finished: epochs={}, last_loss={:.6}",
|
||||
history.len(),
|
||||
history.last().map(|m| m.loss).unwrap_or(0.0)
|
||||
);
|
||||
#[cfg(feature = "burn")]
|
||||
println!("saved checkpoints: outputs/checkpoints/*.safetensors");
|
||||
#[cfg(not(feature = "burn"))]
|
||||
println!("burn feature is off: training skeleton skipped");
|
||||
println!("saved video: outputs/simulation.mp4");
|
||||
println!("target reference: data/toy/aspirin_reference.json");
|
||||
Ok(())
|
||||
}
|
||||
194
src/api.rs
Normal file
194
src/api.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
use crate::{
|
||||
conformer::{ConformerError, LigandConformer},
|
||||
geometry::PoseState,
|
||||
graph::{GraphFeatures, LigandGraph},
|
||||
simulator::{simulate_at_t, simulate_trajectory},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OneSampleToyBatch {
|
||||
pub graph: LigandGraph,
|
||||
pub reference: LigandConformer,
|
||||
pub pose_start: PoseState,
|
||||
pub pose_target: PoseState,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OneSampleToyPipeline {
|
||||
pub batch: OneSampleToyBatch,
|
||||
}
|
||||
|
||||
impl OneSampleToyPipeline {
|
||||
pub fn new(batch: OneSampleToyBatch) -> Self {
|
||||
Self { batch }
|
||||
}
|
||||
|
||||
pub fn graph_contract(&self) -> GraphFeatures {
|
||||
GraphFeatures::from_graph(&self.batch.graph)
|
||||
}
|
||||
|
||||
pub fn sample_at_t(&self, t: f32) -> Result<LigandConformer, ConformerError> {
|
||||
simulate_at_t(
|
||||
&self.batch.reference,
|
||||
&self.batch.pose_start,
|
||||
&self.batch.pose_target,
|
||||
t,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn trajectory(&self, num_steps: usize) -> Result<Vec<LigandConformer>, ConformerError> {
|
||||
simulate_trajectory(
|
||||
&self.batch.reference,
|
||||
&self.batch.pose_start,
|
||||
&self.batch.pose_target,
|
||||
num_steps,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
pub mod burn_training {
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
use safetensors::{
|
||||
serialize_to_file,
|
||||
tensor::{Dtype, TensorView},
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{burn_adapter::GraphTensors, graph::GraphFeatures};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BurnTrainConfig {
|
||||
pub epochs: usize,
|
||||
pub learning_rate: f64,
|
||||
pub checkpoint_every: usize,
|
||||
pub checkpoint_dir: String,
|
||||
}
|
||||
|
||||
impl Default for BurnTrainConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
epochs: 100,
|
||||
learning_rate: 1e-3,
|
||||
checkpoint_every: 10,
|
||||
checkpoint_dir: "outputs/checkpoints".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BurnTrainMetrics {
|
||||
pub epoch: usize,
|
||||
pub loss: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum BurnTrainingError {
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("safetensors error: {0}")]
|
||||
SafeTensors(#[from] safetensors::SafeTensorError),
|
||||
#[error("model error: {0}")]
|
||||
Model(String),
|
||||
}
|
||||
|
||||
pub trait BurnToyModel<B: AutodiffBackend>: Sized {
|
||||
fn name(&self) -> &'static str;
|
||||
fn new(device: &B::Device) -> Self;
|
||||
fn initialize(&mut self, _graph: &GraphFeatures, _device: &B::Device) -> Result<(), String>;
|
||||
fn train_step(&mut self, _graph: &GraphFeatures, epoch: usize, _lr: f64) -> Result<f64, String>;
|
||||
fn save_model_safetensors(&self, path: &Path) -> Result<(), String>;
|
||||
fn load_model_safetensors(device: &B::Device, path: &Path) -> Result<Self, String>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BurnCheckpointPaths {
|
||||
pub model_path: PathBuf,
|
||||
pub trainer_state_path: PathBuf,
|
||||
}
|
||||
|
||||
pub fn graph_contract_to_tensors<B: AutodiffBackend>(
|
||||
graph: &GraphFeatures,
|
||||
device: &B::Device,
|
||||
) -> GraphTensors<B> {
|
||||
GraphTensors::from_features(graph, device)
|
||||
}
|
||||
|
||||
pub fn save_trainer_state_safetensors(
|
||||
path: &Path,
|
||||
epoch: usize,
|
||||
loss: f64,
|
||||
learning_rate: f64,
|
||||
) -> Result<(), BurnTrainingError> {
|
||||
let epoch_bytes = (epoch as i64).to_le_bytes();
|
||||
let loss_bytes = loss.to_le_bytes();
|
||||
let lr_bytes = learning_rate.to_le_bytes();
|
||||
|
||||
let epoch_view = TensorView::new(Dtype::I64, vec![1], &epoch_bytes)?;
|
||||
let loss_view = TensorView::new(Dtype::F64, vec![1], &loss_bytes)?;
|
||||
let lr_view = TensorView::new(Dtype::F64, vec![1], &lr_bytes)?;
|
||||
|
||||
let tensors = BTreeMap::from([
|
||||
("epoch".to_string(), epoch_view),
|
||||
("loss".to_string(), loss_view),
|
||||
("learning_rate".to_string(), lr_view),
|
||||
]);
|
||||
|
||||
serialize_to_file(tensors, &None, path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save_checkpoint_safetensors<B: AutodiffBackend, M: BurnToyModel<B>>(
|
||||
model: &M,
|
||||
config: &BurnTrainConfig,
|
||||
epoch: usize,
|
||||
loss: f64,
|
||||
) -> Result<BurnCheckpointPaths, BurnTrainingError> {
|
||||
fs::create_dir_all(&config.checkpoint_dir)?;
|
||||
|
||||
let model_path = Path::new(&config.checkpoint_dir)
|
||||
.join(format!("{}_epoch_{:04}.model.safetensors", model.name(), epoch));
|
||||
let trainer_state_path = Path::new(&config.checkpoint_dir)
|
||||
.join(format!("{}_epoch_{:04}.trainer.safetensors", model.name(), epoch));
|
||||
|
||||
model
|
||||
.save_model_safetensors(&model_path)
|
||||
.map_err(BurnTrainingError::Model)?;
|
||||
save_trainer_state_safetensors(&trainer_state_path, epoch, loss, config.learning_rate)?;
|
||||
|
||||
Ok(BurnCheckpointPaths {
|
||||
model_path,
|
||||
trainer_state_path,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_burn_training<B: AutodiffBackend, M: BurnToyModel<B>>(
|
||||
model: &mut M,
|
||||
graph: &GraphFeatures,
|
||||
device: &B::Device,
|
||||
config: &BurnTrainConfig,
|
||||
) -> Result<Vec<BurnTrainMetrics>, BurnTrainingError> {
|
||||
model
|
||||
.initialize(graph, device)
|
||||
.map_err(BurnTrainingError::Model)?;
|
||||
|
||||
let mut history = Vec::with_capacity(config.epochs);
|
||||
for epoch in 1..=config.epochs {
|
||||
let loss = model
|
||||
.train_step(graph, epoch, config.learning_rate)
|
||||
.map_err(BurnTrainingError::Model)?;
|
||||
history.push(BurnTrainMetrics { epoch, loss });
|
||||
|
||||
if epoch % config.checkpoint_every == 0 || epoch == config.epochs {
|
||||
let _ = save_checkpoint_safetensors::<B, M>(model, config, epoch, loss)?;
|
||||
}
|
||||
}
|
||||
Ok(history)
|
||||
}
|
||||
}
|
||||
56
src/burn_adapter.rs
Normal file
56
src/burn_adapter.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
#[cfg(feature = "burn")]
|
||||
use burn::tensor::{backend::Backend, Int, Tensor};
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
use crate::geometry::PoseState;
|
||||
#[cfg(feature = "burn")]
|
||||
use crate::graph::GraphFeatures;
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
pub struct GraphTensors<B: Backend> {
|
||||
pub node_features: Tensor<B, 2>,
|
||||
pub edge_index: Tensor<B, 2, Int>,
|
||||
pub edge_features: Option<Tensor<B, 2>>,
|
||||
pub torsion_edge_index: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
impl<B: Backend> GraphTensors<B> {
|
||||
pub fn from_features(features: &GraphFeatures, device: &B::Device) -> Self {
|
||||
Self {
|
||||
node_features: Tensor::<B, 2>::zeros(
|
||||
[features.num_nodes, features.spec.node_feat_dim],
|
||||
device,
|
||||
),
|
||||
edge_index: Tensor::<B, 2, Int>::zeros([2, features.num_edges], device),
|
||||
edge_features: Some(Tensor::<B, 2>::zeros(
|
||||
[features.num_edges, features.spec.edge_feat_dim],
|
||||
device,
|
||||
)),
|
||||
torsion_edge_index: Tensor::<B, 2, Int>::zeros(
|
||||
[2, features.num_torsion_edges],
|
||||
device,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
pub struct PoseTensors<B: Backend> {
|
||||
pub translation: Tensor<B, 2>,
|
||||
pub rotation: Tensor<B, 2>,
|
||||
pub torsions: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "burn")]
|
||||
impl<B: Backend> PoseTensors<B> {
|
||||
pub fn from_pose_batch(poses: &[PoseState], device: &B::Device) -> Self {
|
||||
let batch = poses.len();
|
||||
let torsion_dim = poses.first().map(|p| p.torsions.len()).unwrap_or(0);
|
||||
Self {
|
||||
translation: Tensor::<B, 2>::zeros([batch, 3], device),
|
||||
rotation: Tensor::<B, 2>::zeros([batch, 4], device),
|
||||
torsions: Tensor::<B, 2>::zeros([batch, torsion_dim], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
42
src/conformer.rs
Normal file
42
src/conformer.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::graph::{LigandGraph, TorsionEdge};
|
||||
use crate::geometry::PoseState;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConformerError {
|
||||
#[error("coords length {coords_len} != atoms length {atoms_len}")]
|
||||
CoordinateLengthMismatch { coords_len: usize, atoms_len: usize },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LigandConformer {
|
||||
pub graph: LigandGraph,
|
||||
pub coords: Vec<[f32; 3]>,
|
||||
}
|
||||
|
||||
impl LigandConformer {
|
||||
pub fn new(graph: LigandGraph, coords: Vec<[f32; 3]>) -> Result<Self, ConformerError> {
|
||||
if graph.atoms.len() != coords.len() {
|
||||
return Err(ConformerError::CoordinateLengthMismatch {
|
||||
coords_len: coords.len(),
|
||||
atoms_len: graph.atoms.len(),
|
||||
});
|
||||
}
|
||||
Ok(Self { graph, coords })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply_torsions(
|
||||
coords: Vec<[f32; 3]>,
|
||||
_torsion_edges: &[TorsionEdge],
|
||||
_torsions: &[f32],
|
||||
) -> Vec<[f32; 3]> {
|
||||
// TODO(user-step-2): translation/rotation/torsion 적용 테스트에서 구현.
|
||||
coords
|
||||
}
|
||||
|
||||
pub fn apply_pose(coords: Vec<[f32; 3]>, _pose: &PoseState) -> Vec<[f32; 3]> {
|
||||
// TODO(user-step-2): global rotation + translation 구현.
|
||||
coords
|
||||
}
|
||||
27
src/geometry.rs
Normal file
27
src/geometry.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoseState {
|
||||
pub translation: [f32; 3],
|
||||
pub rotation_quat_xyzw: [f32; 4],
|
||||
pub torsions: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoseDelta {
|
||||
pub translation: [f32; 3],
|
||||
pub rotation_tangent: [f32; 3],
|
||||
pub torsions: Vec<f32>,
|
||||
}
|
||||
|
||||
pub fn interpolate_pose(pose0: &PoseState, _pose1: &PoseState, _t: f32) -> PoseState {
|
||||
// TODO(user-step-4): geodesic/interpolation 정책 확정 후 구현.
|
||||
pose0.clone()
|
||||
}
|
||||
|
||||
pub fn pose_delta(pose0: &PoseState, _pose1: &PoseState) -> PoseDelta {
|
||||
// TODO(user-step-4): geodesic delta 정의 후 구현.
|
||||
PoseDelta {
|
||||
translation: pose0.translation,
|
||||
rotation_tangent: [0.0, 0.0, 0.0],
|
||||
torsions: vec![0.0; pose0.torsions.len()],
|
||||
}
|
||||
}
|
||||
83
src/graph.rs
Normal file
83
src/graph.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Atom {
|
||||
pub index: usize,
|
||||
pub atomic_number: u8,
|
||||
pub formal_charge: i8,
|
||||
pub is_aromatic: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Bond {
|
||||
pub index: usize,
|
||||
pub src: usize,
|
||||
pub dst: usize,
|
||||
pub bond_order: u8,
|
||||
pub is_aromatic: bool,
|
||||
pub is_rotatable: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TorsionEdge {
|
||||
pub bond_index: usize,
|
||||
pub atom_left: usize,
|
||||
pub atom_right: usize,
|
||||
pub rotating_side: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LigandGraph {
|
||||
pub atoms: Vec<Atom>,
|
||||
pub bonds: Vec<Bond>,
|
||||
pub edge_index: Vec<(usize, usize)>,
|
||||
pub torsion_edges: Vec<TorsionEdge>,
|
||||
}
|
||||
|
||||
impl LigandGraph {
|
||||
pub fn new(atoms: Vec<Atom>, bonds: Vec<Bond>, torsion_edges: Vec<TorsionEdge>) -> Self {
|
||||
let mut edge_index = Vec::with_capacity(bonds.len() * 2);
|
||||
for b in &bonds {
|
||||
edge_index.push((b.src, b.dst));
|
||||
edge_index.push((b.dst, b.src));
|
||||
}
|
||||
Self {
|
||||
atoms,
|
||||
bonds,
|
||||
edge_index,
|
||||
torsion_edges,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GraphFeatureSpec {
|
||||
pub node_feat_dim: usize,
|
||||
pub edge_feat_dim: usize,
|
||||
}
|
||||
|
||||
impl Default for GraphFeatureSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
node_feat_dim: 4,
|
||||
edge_feat_dim: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GraphFeatures {
|
||||
pub num_nodes: usize,
|
||||
pub num_edges: usize,
|
||||
pub num_torsion_edges: usize,
|
||||
pub spec: GraphFeatureSpec,
|
||||
}
|
||||
|
||||
impl GraphFeatures {
|
||||
pub fn from_graph(graph: &LigandGraph) -> Self {
|
||||
Self {
|
||||
num_nodes: graph.atoms.len(),
|
||||
num_edges: graph.edge_index.len(),
|
||||
num_torsion_edges: graph.torsion_edges.len(),
|
||||
spec: GraphFeatureSpec::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
10
src/lib.rs
Normal file
10
src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
pub mod api;
|
||||
pub mod burn_adapter;
|
||||
pub mod conformer;
|
||||
pub mod geometry;
|
||||
pub mod graph;
|
||||
pub mod simulator;
|
||||
pub mod viz;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
29
src/simulator.rs
Normal file
29
src/simulator.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::conformer::{ConformerError, LigandConformer};
|
||||
use crate::geometry::PoseState;
|
||||
|
||||
pub fn simulate_at_t(
|
||||
reference: &LigandConformer,
|
||||
pose_start: &PoseState,
|
||||
pose_target: &PoseState,
|
||||
t: f32,
|
||||
) -> Result<LigandConformer, ConformerError> {
|
||||
let _ = (pose_start, pose_target, t);
|
||||
// TODO(user-step-4): geodesic/interpolation 정책 확정 후 구현.
|
||||
LigandConformer::new(reference.graph.clone(), reference.coords.clone())
|
||||
}
|
||||
|
||||
pub fn simulate_trajectory(
|
||||
reference: &LigandConformer,
|
||||
pose_start: &PoseState,
|
||||
pose_target: &PoseState,
|
||||
num_steps: usize,
|
||||
) -> Result<Vec<LigandConformer>, ConformerError> {
|
||||
let n = num_steps.max(2);
|
||||
// TODO(user-step-7): epoch 단위 저장 포맷과 결합할 때 metadata(time/epoch) 추가.
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let t = i as f32 / (n - 1) as f32;
|
||||
simulate_at_t(reference, pose_start, pose_target, t)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
40
src/tests.rs
Normal file
40
src/tests.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
api::{OneSampleToyBatch, OneSampleToyPipeline},
|
||||
conformer::LigandConformer,
|
||||
geometry::PoseState,
|
||||
graph::{Atom, Bond, LigandGraph, TorsionEdge},
|
||||
};
|
||||
|
||||
fn simple_reference() -> LigandConformer {
|
||||
let atoms = vec![
|
||||
Atom { index: 0, atomic_number: 6, formal_charge: 0, is_aromatic: false },
|
||||
Atom { index: 1, atomic_number: 6, formal_charge: 0, is_aromatic: false },
|
||||
Atom { index: 2, atomic_number: 8, formal_charge: 0, is_aromatic: false },
|
||||
];
|
||||
let bonds = vec![
|
||||
Bond { index: 0, src: 0, dst: 1, bond_order: 1, is_aromatic: false, is_rotatable: true },
|
||||
Bond { index: 1, src: 1, dst: 2, bond_order: 1, is_aromatic: false, is_rotatable: false },
|
||||
];
|
||||
let torsion = vec![TorsionEdge { bond_index: 0, atom_left: 0, atom_right: 1, rotating_side: vec![2] }];
|
||||
let graph = LigandGraph::new(atoms, bonds, torsion);
|
||||
LigandConformer::new(graph, vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.2, 0.0]]).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_sample_pipeline_contract_smoke_test() {
|
||||
let reference = simple_reference();
|
||||
let p0 = PoseState { translation: [0.0, 0.0, 0.0], rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0], torsions: vec![0.0] };
|
||||
let p1 = PoseState { translation: [0.5, -0.2, 1.0], rotation_quat_xyzw: [0.0, 0.0, 0.0, 1.0], torsions: vec![0.1] };
|
||||
let pipeline = OneSampleToyPipeline::new(OneSampleToyBatch {
|
||||
graph: reference.graph.clone(),
|
||||
reference,
|
||||
pose_start: p0,
|
||||
pose_target: p1,
|
||||
});
|
||||
let contract = pipeline.graph_contract();
|
||||
assert_eq!(contract.spec.node_feat_dim, 4);
|
||||
assert_eq!(contract.num_nodes, 3);
|
||||
}
|
||||
}
|
||||
177
src/viz.rs
Normal file
177
src/viz.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
};
|
||||
|
||||
use plotters::prelude::*;
|
||||
|
||||
use crate::{
|
||||
conformer::LigandConformer,
|
||||
geometry::PoseState,
|
||||
simulator::simulate_trajectory,
|
||||
};
|
||||
|
||||
fn frame_bounds(frames: &[LigandConformer]) -> Option<((f32, f32), (f32, f32))> {
|
||||
let mut min_x = f32::INFINITY;
|
||||
let mut max_x = f32::NEG_INFINITY;
|
||||
let mut min_y = f32::INFINITY;
|
||||
let mut max_y = f32::NEG_INFINITY;
|
||||
|
||||
for frame in frames {
|
||||
for c in &frame.coords {
|
||||
min_x = min_x.min(c[0]);
|
||||
max_x = max_x.max(c[0]);
|
||||
min_y = min_y.min(c[1]);
|
||||
max_y = max_y.max(c[1]);
|
||||
}
|
||||
}
|
||||
|
||||
if min_x.is_finite() && max_x.is_finite() && min_y.is_finite() && max_y.is_finite() {
|
||||
Some(((min_x, max_x), (min_y, max_y)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn draw_conformer_png<P: AsRef<Path>>(
|
||||
conformer: &LigandConformer,
|
||||
out: P,
|
||||
x_range: (f32, f32),
|
||||
y_range: (f32, f32),
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let root = BitMapBackend::new(out.as_ref(), (960, 960)).into_drawing_area();
|
||||
root.fill(&WHITE)?;
|
||||
|
||||
let mut chart = ChartBuilder::on(&root)
|
||||
.margin(12)
|
||||
.build_cartesian_2d(x_range.0..x_range.1, y_range.0..y_range.1)?;
|
||||
|
||||
chart
|
||||
.configure_mesh()
|
||||
.disable_mesh()
|
||||
.x_labels(0)
|
||||
.y_labels(0)
|
||||
.draw()?;
|
||||
|
||||
for b in &conformer.graph.bonds {
|
||||
let a = conformer.coords[b.src];
|
||||
let c = conformer.coords[b.dst];
|
||||
let color = if b.is_rotatable { &RED } else { &BLACK };
|
||||
chart.draw_series(std::iter::once(PathElement::new(
|
||||
vec![(a[0], a[1]), (c[0], c[1])],
|
||||
color.stroke_width(3),
|
||||
)))?;
|
||||
}
|
||||
|
||||
chart.draw_series(
|
||||
conformer
|
||||
.coords
|
||||
.iter()
|
||||
.map(|c| Circle::new((c[0], c[1]), 5, BLUE.filled())),
|
||||
)?;
|
||||
|
||||
root.present()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn plot_static_png<P: AsRef<Path>>(
|
||||
conformer: &LigandConformer,
|
||||
out: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let single = [conformer.clone()];
|
||||
let ((min_x, max_x), (min_y, max_y)) =
|
||||
frame_bounds(&single).ok_or("failed to compute frame bounds")?;
|
||||
let pad = 1.0f32;
|
||||
draw_conformer_png(
|
||||
conformer,
|
||||
out,
|
||||
(min_x - pad, max_x + pad),
|
||||
(min_y - pad, max_y + pad),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn plot_trajectory_png_frames<P: AsRef<Path>>(
|
||||
frames: &[LigandConformer],
|
||||
out_dir: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
fs::create_dir_all(out_dir.as_ref())?;
|
||||
let ((min_x, max_x), (min_y, max_y)) = frame_bounds(frames).ok_or("empty frame list")?;
|
||||
let pad = 1.0f32;
|
||||
let xr = (min_x - pad, max_x + pad);
|
||||
let yr = (min_y - pad, max_y + pad);
|
||||
|
||||
for (idx, frame) in frames.iter().enumerate() {
|
||||
let out = out_dir.as_ref().join(format!("frame_{idx:04}.png"));
|
||||
draw_conformer_png(frame, out, xr, yr)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn export_xyz_frames<P: AsRef<Path>>(
|
||||
frames: &[LigandConformer],
|
||||
out_dir: P,
|
||||
) -> Result<(), std::io::Error> {
|
||||
fs::create_dir_all(out_dir.as_ref())?;
|
||||
for (idx, frame) in frames.iter().enumerate() {
|
||||
let out = out_dir.as_ref().join(format!("frame_{idx:04}.xyz"));
|
||||
let mut f = File::create(out)?;
|
||||
writeln!(f, "{}", frame.coords.len())?;
|
||||
writeln!(f, "riemann-flow-gnn frame {idx}")?;
|
||||
for (atom, c) in frame.graph.atoms.iter().zip(&frame.coords) {
|
||||
writeln!(f, "{} {:.6} {:.6} {:.6}", atom.atomic_number, c[0], c[1], c[2])?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn frames_to_mp4<P: AsRef<Path>>(
|
||||
frames_dir: P,
|
||||
output_mp4: P,
|
||||
fps: u32,
|
||||
) -> Result<bool, Box<dyn std::error::Error>> {
|
||||
let input_pattern: PathBuf = frames_dir.as_ref().join("frame_%04d.png");
|
||||
let status = match Command::new("ffmpeg")
|
||||
.args([
|
||||
"-y",
|
||||
"-framerate",
|
||||
&fps.to_string(),
|
||||
"-i",
|
||||
input_pattern.to_string_lossy().as_ref(),
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
output_mp4.as_ref().to_string_lossy().as_ref(),
|
||||
])
|
||||
.status()
|
||||
{
|
||||
Ok(status) => status,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
|
||||
Err(e) => return Err(Box::new(e)),
|
||||
};
|
||||
|
||||
if !status.success() {
|
||||
return Err("ffmpeg command failed".into());
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub fn export_simulation_video<P: AsRef<Path>>(
|
||||
reference: &LigandConformer,
|
||||
pose_start: &PoseState,
|
||||
pose_target: &PoseState,
|
||||
num_steps: usize,
|
||||
out_dir: P,
|
||||
fps: u32,
|
||||
) -> Result<Vec<LigandConformer>, Box<dyn std::error::Error>> {
|
||||
fs::create_dir_all(out_dir.as_ref())?;
|
||||
let frames_png_dir = out_dir.as_ref().join("frames_png");
|
||||
let frames_xyz_dir = out_dir.as_ref().join("frames_xyz");
|
||||
let video_path = out_dir.as_ref().join("simulation.mp4");
|
||||
|
||||
let frames = simulate_trajectory(reference, pose_start, pose_target, num_steps)?;
|
||||
plot_trajectory_png_frames(&frames, &frames_png_dir)?;
|
||||
export_xyz_frames(&frames, &frames_xyz_dir)?;
|
||||
let _video_written = frames_to_mp4(&frames_png_dir, &video_path, fps)?;
|
||||
Ok(frames)
|
||||
}
|
||||
Reference in New Issue
Block a user