
Wasserstein Generative Adversarial Network

A Burn implementation of an example WGAN model that generates MNIST digits, inspired by the PyTorch implementation. Note that better results may be achieved by adopting convolutional layers, as some other models do.
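For readers new to WGANs, the objective that the critic and generator optimize can be written as a small plain-Rust sketch. It assumes the original weight-clipping WGAN formulation. The function names, the sample scores and the clip value are illustrative assumptions, not the example's actual code, which operates on Burn tensors.

```rust
/// Framework-free sketch of the WGAN objective (weight-clipping variant).
/// Scores are hypothetical critic outputs for one batch; a real model uses tensors.

fn mean(xs: &[f32]) -> f32 {
    xs.iter().sum::<f32>() / xs.len() as f32
}

/// Critic loss: minimizing mean(D(fake)) - mean(D(real)) pushes the critic to
/// widen the gap between real and generated scores (a Wasserstein estimate).
fn critic_loss(critic_real: &[f32], critic_fake: &[f32]) -> f32 {
    mean(critic_fake) - mean(critic_real)
}

/// Generator loss: the generator tries to raise the critic's score on fakes.
fn generator_loss(critic_fake: &[f32]) -> f32 {
    -mean(critic_fake)
}

/// Weight clipping keeps the critic roughly Lipschitz, as in the original WGAN.
/// `clip` is a hyperparameter (assumed small, e.g. 0.01).
fn clip_weights(weights: &mut [f32], clip: f32) {
    for w in weights.iter_mut() {
        *w = (*w).clamp(-clip, clip);
    }
}

fn main() {
    let real = [0.8_f32, 0.6, 0.7];
    let fake = [-0.2_f32, 0.1, 0.0];
    println!("critic loss: {}", critic_loss(&real, &fake));
    println!("generator loss: {}", generator_loss(&fake));

    let mut weights = [0.5_f32, -0.003, -0.2];
    clip_weights(&mut weights, 0.01);
    println!("clipped weights: {:?}", weights);
}
```

In a training loop, the critic is usually updated several times per generator step, with clipping applied after each critic update.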

Usage

Training

```sh
# CUDA backend
cargo run --example wgan-mnist --release --features cuda

# Wgpu backend
cargo run --example wgan-mnist --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu128 # Set the CUDA version
cargo run --example wgan-mnist --release --features tch-gpu

# Tch CPU backend
cargo run --example wgan-mnist --release --features tch-cpu

# Flex backend (CPU)
cargo run --example wgan-mnist --release --features flex  # f32
```
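Each --features flag above enables one backend at compile time. The following is a hypothetical sketch of how such a mapping is commonly written with Rust's cfg! macro. The Backend enum, the selected_backend function and the priority order are illustrative assumptions, not the example's actual code.

```rust
/// Hypothetical mapping from Cargo feature flags to a backend choice.
#[derive(Debug, PartialEq)]
enum Backend {
    Cuda,
    Wgpu,
    TchGpu,
    TchCpu,
    Flex,
}

/// Returns the first backend whose feature is enabled (in an assumed priority
/// order), or `None` if the crate was built without any backend feature.
fn selected_backend() -> Option<Backend> {
    if cfg!(feature = "cuda") {
        Some(Backend::Cuda)
    } else if cfg!(feature = "wgpu") {
        Some(Backend::Wgpu)
    } else if cfg!(feature = "tch-gpu") {
        Some(Backend::TchGpu)
    } else if cfg!(feature = "tch-cpu") {
        Some(Backend::TchCpu)
    } else if cfg!(feature = "flex") {
        Some(Backend::Flex)
    } else {
        None
    }
}

fn main() {
    match selected_backend() {
        Some(backend) => println!("training with {:?}", backend),
        None => eprintln!("no backend feature enabled; pass e.g. --features wgpu"),
    }
}
```

Because cfg! is resolved at compile time, unselected branches cost nothing at runtime. Real crates often use #[cfg(...)] attributes instead, so that code for a disabled backend is not compiled at all.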

Generating

To generate sample images, use the wgan-generate example. It takes the same feature flags to select a backend.

```sh
cargo run --example wgan-generate --release --features cuda
```
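Generated samples must be converted from model outputs to pixel values before they can be saved as images. Below is a minimal sketch of that step. It assumes, as in the PyTorch WGAN reference, that the generator ends in tanh, so outputs lie in [-1, 1], and that each sample is a flattened 28x28 MNIST digit. The to_pixel and to_image helpers are hypothetical, not part of the example.

```rust
/// Maps one tanh-range value in [-1, 1] to an 8-bit grayscale pixel.
fn to_pixel(x: f32) -> u8 {
    // Clamp guards against slight overshoot, then shift [-1, 1] to [0, 255].
    ((x.clamp(-1.0, 1.0) + 1.0) * 127.5).round() as u8
}

/// Converts a flattened generator output (assumed 28 * 28 values) to bytes.
fn to_image(output: &[f32]) -> Vec<u8> {
    output.iter().map(|&x| to_pixel(x)).collect()
}

fn main() {
    // Placeholder standing in for one generated sample.
    let fake_output = vec![0.0_f32; 28 * 28];
    let image = to_image(&fake_output);
    println!("{} pixels, first = {}", image.len(), image[0]);
}
```

The resulting bytes can then be written out with any image library that accepts 8-bit grayscale buffers.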