# Wasserstein Generative Adversarial Network
A Burn implementation of an example WGAN model that generates MNIST digits, inspired by the PyTorch implementation. Note that better results may be obtained by using convolutional layers, as some other models do.
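The WGAN objective differs from a standard GAN: the critic outputs an unbounded score instead of a probability, and it is kept roughly Lipschitz. The sketch below shows the original WGAN formulation with weight clipping in plain Rust, with no Burn types. It is an illustration under stated assumptions, not this example's actual code. The function names (`critic_loss`, `generator_loss`, `clip_weights`) are hypothetical, and the clip value `0.01` is only illustrative.

```rust
/// Mean of a slice of critic scores.
fn mean(xs: &[f32]) -> f32 {
    xs.iter().sum::<f32>() / xs.len() as f32
}

/// Hypothetical critic loss to minimize: E[D(fake)] - E[D(real)].
/// Minimizing it pushes real scores up and fake scores down, so the gap
/// estimates the Wasserstein distance between the two distributions.
fn critic_loss(real_scores: &[f32], fake_scores: &[f32]) -> f32 {
    mean(fake_scores) - mean(real_scores)
}

/// Hypothetical generator loss to minimize: -E[D(fake)].
fn generator_loss(fake_scores: &[f32]) -> f32 {
    -mean(fake_scores)
}

/// Weight clipping as in the original WGAN: every critic weight is clamped
/// to [-c, c], a crude way to keep the critic Lipschitz.
fn clip_weights(weights: &mut [f32], c: f32) {
    for w in weights.iter_mut() {
        *w = w.clamp(-c, c);
    }
}

fn main() {
    // Made-up critic scores for a batch of real and generated images.
    let real = [0.8_f32, 1.2, 1.0];
    let fake = [-0.5_f32, 0.1, -0.2];

    let d_loss = critic_loss(&real, &fake);
    let g_loss = generator_loss(&fake);

    // Made-up critic weights, clipped after an optimizer step.
    let mut weights = vec![0.05_f32, -0.3, 0.004];
    clip_weights(&mut weights, 0.01);

    println!("critic loss = {d_loss:.3}, generator loss = {g_loss:.3}");
    println!("clipped weights = {weights:?}");
}
```

In a full training loop the critic is usually updated several times per generator step, with clipping applied after each critic update. Later variants replace clipping with a gradient penalty.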
## Usage

### Training
```sh
# CUDA backend
cargo run --example wgan-mnist --release --features cuda

# Wgpu backend
cargo run --example wgan-mnist --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu128 # Set the CUDA version
cargo run --example wgan-mnist --release --features tch-gpu

# Tch CPU backend
cargo run --example wgan-mnist --release --features tch-cpu

# Flex backend (CPU)
cargo run --example wgan-mnist --release --features flex # f32
```
### Generating
To generate a sample of images, use the `wgan-generate` example. The backend is selected with the same feature flags used for training.
```sh
cargo run --example wgan-generate --release --features cuda
```