feat(wgpu): use WgpuRuntime compiler generic to enable specialized aliases (#5001)

* WIP

* Add docs

* Fix feature flags

* Fix device type

* Fix compilation

* Update revs

* Fmt
This commit is contained in:
Nathaniel Simard
2026-05-25 14:02:04 -04:00
committed by GitHub
parent 70660fedce
commit cd8145ee15
12 changed files with 423 additions and 246 deletions

66
Cargo.lock generated
View File

@@ -2119,7 +2119,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"cubecl-core",
"cubecl-cpu",
@@ -2135,7 +2135,7 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"backtrace",
"bytemuck",
@@ -2176,7 +2176,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bitflags 2.11.1",
"bytemuck",
@@ -2205,7 +2205,7 @@ dependencies = [
[[package]]
name = "cubecl-cpp"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2221,7 +2221,7 @@ dependencies = [
[[package]]
name = "cubecl-cpu"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2242,7 +2242,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2260,7 +2260,7 @@ dependencies = [
[[package]]
name = "cubecl-hip"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2289,7 +2289,7 @@ dependencies = [
[[package]]
name = "cubecl-ir"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bumpalo",
"cubecl-common",
@@ -2313,7 +2313,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"cubecl-common",
"darling 0.23.0",
@@ -2330,7 +2330,7 @@ dependencies = [
[[package]]
name = "cubecl-macros-internal"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"darling 0.23.0",
"proc-macro2",
@@ -2341,7 +2341,7 @@ dependencies = [
[[package]]
name = "cubecl-opt"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"cubecl-common",
"cubecl-core",
@@ -2359,7 +2359,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"ahash",
"async-channel",
@@ -2390,7 +2390,7 @@ dependencies = [
[[package]]
name = "cubecl-spirv"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"bitflags 2.11.1",
"cubecl-common",
@@ -2406,7 +2406,7 @@ dependencies = [
[[package]]
name = "cubecl-std"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"cubecl-common",
"cubecl-core",
@@ -2422,7 +2422,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"ash",
"async-channel",
@@ -2451,7 +2451,7 @@ dependencies = [
[[package]]
name = "cubecl-zspace"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
dependencies = [
"derive-new",
"serde",
@@ -2461,7 +2461,7 @@ dependencies = [
[[package]]
name = "cubek"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubek-attention",
@@ -2479,7 +2479,7 @@ dependencies = [
[[package]]
name = "cubek-attention"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"bytemuck",
"cubecl",
@@ -2493,7 +2493,7 @@ dependencies = [
[[package]]
name = "cubek-convolution"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"bytemuck",
"cubecl",
@@ -2509,7 +2509,7 @@ dependencies = [
[[package]]
name = "cubek-fft"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
]
@@ -2517,7 +2517,7 @@ dependencies = [
[[package]]
name = "cubek-interpolate"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2528,7 +2528,7 @@ dependencies = [
[[package]]
name = "cubek-matmul"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"bytemuck",
"cubecl",
@@ -2541,7 +2541,7 @@ dependencies = [
[[package]]
name = "cubek-pool"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2551,7 +2551,7 @@ dependencies = [
[[package]]
name = "cubek-quant"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2562,7 +2562,7 @@ dependencies = [
[[package]]
name = "cubek-random"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2575,7 +2575,7 @@ dependencies = [
[[package]]
name = "cubek-reduce"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubek-std",
@@ -2588,7 +2588,7 @@ dependencies = [
[[package]]
name = "cubek-std"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
dependencies = [
"cubecl",
"cubecl-common",
@@ -4869,9 +4869,9 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
[[package]]
name = "jiff"
version = "0.2.24"
version = "0.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d"
checksum = "6835eea34fb6321b9b3aa7b685c2b433948c09447e389dc017fdf687d5d11e65"
dependencies = [
"jiff-static",
"log",
@@ -4882,9 +4882,9 @@ dependencies = [
[[package]]
name = "jiff-static"
version = "0.2.24"
version = "0.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7"
checksum = "3c22e04db9c58f5136eb1757f3d5c49a7b187f49e52185228cbd2f5acdfcc08c"
dependencies = [
"proc-macro2",
"quote",
@@ -5132,9 +5132,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.29"
version = "0.4.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5"
[[package]]
name = "loop9"

View File

@@ -219,10 +219,10 @@ burn-vision = { path = "crates/burn-vision", version = "0.22.0-pre.1", default-f
burn-wgpu = { path = "crates/burn-wgpu", version = "0.22.0-pre.1", default-features = false }
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a783fdef4997d18c6333600c049f898b0b4fc7c" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a783fdef4997d18c6333600c049f898b0b4fc7c" }
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a783fdef4997d18c6333600c049f898b0b4fc7c" }
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b4a62bd93b677f83f95108abb5bb893be75729d7" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b5983f200457271d7e336f0db903eda54834b924" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b5983f200457271d7e336f0db903eda54834b924" }
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b5983f200457271d7e336f0db903eda54834b924" }
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "deba7f003af37a231e38bebda293a89df338f467" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

View File

@@ -1,7 +1,4 @@
fn main() {
println!("cargo::rustc-check-cfg=cfg(wgpu_metal)");
println!("cargo::rustc-check-cfg=cfg(wgpu_vulkan)");
println!("cargo::rustc-check-cfg=cfg(wgpu_webgpu)");
println!("cargo::rustc-check-cfg=cfg(default_backend)");
// If you try to build with `--no-default-features`, we enable a cpu backend by default (Flex)
@@ -11,47 +8,15 @@ fn main() {
let ndarray = cfg!(feature = "ndarray");
let tch = cfg!(feature = "tch");
let cpu = cfg!(feature = "cpu");
let mut metal = cfg!(feature = "metal");
let mut vulkan = cfg!(feature = "vulkan");
let mut webgpu = cfg!(feature = "webgpu");
let metal = cfg!(feature = "metal");
let vulkan = cfg!(feature = "vulkan");
let webgpu = cfg!(feature = "webgpu");
let wgpu = cfg!(feature = "wgpu");
// Detect which single wgpu backend is enabled
let wgpu_only = cfg!(all(
feature = "wgpu",
not(feature = "metal"),
not(feature = "vulkan")
));
let enabled = [(metal, "metal"), (vulkan, "vulkan"), (webgpu, "webgpu")]
.iter()
.filter(|x| x.0)
.map(|x| x.1)
.collect::<Vec<_>>();
// WGPU features are mutually exclusive, but we don't want the workspace to throw a compile error.
// In workspace builds with multiple features, we emit a warning and fall back to WebGpu/Wgpu.
if enabled.len() > 1 {
webgpu = true;
vulkan = false;
metal = false;
println!(
"cargo:warning=Only one wgpu backend can be enabled at once. Detected: [{}]. Falling back to `wgpu`. For production, enable only one of: metal, vulkan, or webgpu.",
enabled.join(", ")
);
}
let no_backend_enabled =
!(cuda || flex || rocm || ndarray || tch || cpu || metal || vulkan || webgpu || wgpu_only);
!(cuda || flex || rocm || ndarray || tch || cpu || metal || vulkan || webgpu || wgpu);
if metal {
println!("cargo:rustc-cfg=wgpu_metal");
}
if vulkan {
println!("cargo:rustc-cfg=wgpu_vulkan");
}
if webgpu || wgpu_only {
println!("cargo:rustc-cfg=wgpu_webgpu");
}
if no_backend_enabled {
println!("cargo:rustc-cfg=default_backend");
}

View File

@@ -92,8 +92,16 @@ impl Backend for Dispatch {
DispatchDeviceId::Cpu => Cpu::device_count(backend_type_id),
#[cfg(feature = "cuda")]
DispatchDeviceId::Cuda => Cuda::device_count(backend_type_id),
#[cfg(feature = "metal")]
DispatchDeviceId::Metal => Metal::device_count(backend_type_id),
#[cfg(feature = "rocm")]
DispatchDeviceId::Rocm => Rocm::device_count(backend_type_id),
#[cfg(feature = "vulkan")]
DispatchDeviceId::Vulkan => Vulkan::device_count(backend_type_id),
#[cfg(feature = "wgpu")]
DispatchDeviceId::Wgpu => Wgpu::device_count(backend_type_id),
#[cfg(feature = "webgpu")]
DispatchDeviceId::WebGpu => WebGpu::device_count(backend_type_id),
#[cfg(feature = "flex")]
DispatchDeviceId::Flex => Flex::device_count(backend_type_id),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -150,10 +158,16 @@ impl AutodiffBackend for Dispatch {
DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "cuda")]
DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "metal")]
DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "rocm")]
DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "vulkan")]
DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "wgpu")]
DispatchTensorKind::Wgpu(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "webgpu")]
DispatchTensorKind::WebGpu(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "flex")]
DispatchTensorKind::Flex(tensor) => tensor.autodiff().backward(),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -187,16 +201,31 @@ impl AutodiffBackend for Dispatch {
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
#[cfg(feature = "metal")]
DispatchTensorKind::Metal(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
#[cfg(feature = "rocm")]
DispatchTensorKind::Rocm(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
#[cfg(feature = "vulkan")]
DispatchTensorKind::Vulkan(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
#[cfg(feature = "wgpu")]
DispatchTensorKind::Wgpu(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "webgpu")]
DispatchTensorKind::WebGpu(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "flex")]
DispatchTensorKind::Flex(tensor) => tensor
.as_autodiff()
@@ -246,16 +275,31 @@ impl AutodiffBackend for Dispatch {
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
#[cfg(feature = "metal")]
DispatchTensorKind::Metal(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
#[cfg(feature = "rocm")]
DispatchTensorKind::Rocm(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
#[cfg(feature = "vulkan")]
DispatchTensorKind::Vulkan(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
#[cfg(feature = "wgpu")]
DispatchTensorKind::Wgpu(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "webgpu")]
DispatchTensorKind::WebGpu(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "flex")]
DispatchTensorKind::Flex(tensor) => tensor
.as_autodiff()
@@ -309,14 +353,26 @@ impl AutodiffBackend for Dispatch {
(DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "metal")]
(DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "rocm")]
(DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "vulkan")]
(DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "wgpu")]
(DispatchTensorKind::Wgpu(tensor), DispatchTensorKind::Wgpu(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "webgpu")]
(DispatchTensorKind::WebGpu(tensor), DispatchTensorKind::WebGpu(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "flex")]
(DispatchTensorKind::Flex(tensor), DispatchTensorKind::Flex(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
@@ -356,14 +412,26 @@ impl AutodiffBackend for Dispatch {
DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
#[cfg(feature = "metal")]
DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
#[cfg(feature = "rocm")]
DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
#[cfg(feature = "vulkan")]
DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
#[cfg(feature = "wgpu")]
DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Wgpu(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
#[cfg(feature = "webgpu")]
DispatchTensorKind::WebGpu(tensor) => DispatchTensorKind::WebGpu(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
#[cfg(feature = "flex")]
DispatchTensorKind::Flex(tensor) => DispatchTensorKind::Flex(
crate::BackendTensor::Float(tensor.autodiff().primitive),
@@ -423,18 +491,36 @@ impl AutodiffBackend for Dispatch {
crate::BackendTensor::Autodiff(Autodiff::<Cuda>::from_inner(tensor.float())),
)))
}
#[cfg(feature = "metal")]
DispatchTensorKind::Metal(tensor) => {
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Metal(
crate::BackendTensor::Autodiff(Autodiff::<Metal>::from_inner(tensor.float())),
)))
}
#[cfg(feature = "rocm")]
DispatchTensorKind::Rocm(tensor) => {
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Rocm(
crate::BackendTensor::Autodiff(Autodiff::<Rocm>::from_inner(tensor.float())),
)))
}
#[cfg(feature = "vulkan")]
DispatchTensorKind::Vulkan(tensor) => {
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Vulkan(
crate::BackendTensor::Autodiff(Autodiff::<Vulkan>::from_inner(tensor.float())),
)))
}
#[cfg(feature = "wgpu")]
DispatchTensorKind::Wgpu(tensor) => {
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Wgpu(
crate::BackendTensor::Autodiff(Autodiff::<Wgpu>::from_inner(tensor.float())),
)))
}
#[cfg(feature = "webgpu")]
DispatchTensorKind::WebGpu(tensor) => {
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::WebGpu(
crate::BackendTensor::Autodiff(Autodiff::<WebGpu>::from_inner(tensor.float())),
)))
}
#[cfg(feature = "flex")]
DispatchTensorKind::Flex(tensor) => {
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Flex(
@@ -548,10 +634,16 @@ impl DispatchTensorKind {
DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
#[cfg(feature = "cuda")]
DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
#[cfg(feature = "metal")]
DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
#[cfg(feature = "rocm")]
DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
#[cfg(feature = "vulkan")]
DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
#[cfg(feature = "wgpu")]
DispatchTensorKind::Wgpu(tensor) => DispatchDevice::Wgpu(tensor.device()),
#[cfg(feature = "webgpu")]
DispatchTensorKind::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),
#[cfg(feature = "flex")]
DispatchTensorKind::Flex(tensor) => DispatchDevice::Flex(tensor.device()),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -597,13 +689,25 @@ impl Dispatch {
DispatchDeviceId::Cuda => (0..Cuda::device_count(0))
.map(|i| CudaDevice::new(i).into())
.collect(),
#[cfg(feature = "metal")]
DispatchDeviceId::Metal => (0..Metal::device_count(0))
.map(|i| DispatchDevice::Metal(WgpuDevice::DiscreteGpu(i)))
.collect(),
#[cfg(feature = "rocm")]
DispatchDeviceId::Rocm => (0..Rocm::device_count(0))
.map(|i| RocmDevice::new(i).into())
.collect(),
#[cfg(feature = "vulkan")]
DispatchDeviceId::Vulkan => (0..Vulkan::device_count(0))
.map(|i| DispatchDevice::Vulkan(WgpuDevice::DiscreteGpu(i)))
.collect(),
#[cfg(feature = "wgpu")]
DispatchDeviceId::Wgpu => (0..Wgpu::device_count(0))
.map(|i| WgpuDevice::DiscreteGpu(i).into())
.map(|i| DispatchDevice::Wgpu(WgpuDevice::DiscreteGpu(i)))
.collect(),
#[cfg(feature = "webgpu")]
DispatchDeviceId::WebGpu => (0..WebGpu::device_count(0))
.map(|i| DispatchDevice::WebGpu(WgpuDevice::DiscreteGpu(i)))
.collect(),
#[cfg(feature = "flex")]
DispatchDeviceId::Flex => vec![FlexDevice.into()],

View File

@@ -30,25 +30,26 @@ pub enum DispatchDevice {
#[cfg(feature = "cuda")]
Cuda(CudaDevice),
/// The [Metal backend](Metal) device (via WGPU runtime).
#[cfg(feature = "metal")]
Metal(WgpuDevice),
/// The [ROCm backend](Rocm) device.
#[cfg(feature = "rocm")]
Rocm(RocmDevice),
#[cfg_attr(
wgpu_webgpu,
doc = r#"The [WebGPU backend](Wgpu) device (via WGPU runtime)"#
)]
#[cfg_attr(
wgpu_vulkan,
doc = r#"The [Vulkan backend](Vulkan) device (via WGPU runtime)"#
)]
#[cfg_attr(
wgpu_metal,
doc = r#"The [Metal backend](Metal) device (via WGPU runtime)"#
)]
/// The [Vulkan backend](Vulkan) device.
#[cfg(feature = "vulkan")]
Vulkan(WgpuDevice),
/// The [Wgpu backend](Wgpu) device (via WGPU runtime with auto-selected compiler).
#[cfg(feature = "wgpu")]
Wgpu(WgpuDevice),
/// The [WebGPU backend](WebGpu) device (via WGPU runtime).
#[cfg(feature = "webgpu")]
WebGpu(WgpuDevice),
/// The [Flex backend](Flex) device (CPU-only).
#[cfg(feature = "flex")]
Flex(FlexDevice),
@@ -145,14 +146,16 @@ impl core::fmt::Debug for DispatchDevice {
Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
#[cfg(feature = "cuda")]
Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
#[cfg(feature = "metal")]
Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
#[cfg(feature = "rocm")]
Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
#[cfg(wgpu_metal)]
Self::Wgpu(device) => f.debug_tuple("Metal").field(device).finish(),
#[cfg(wgpu_vulkan)]
Self::Wgpu(device) => f.debug_tuple("Vulkan").field(device).finish(),
#[cfg(wgpu_webgpu)]
#[cfg(feature = "vulkan")]
Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
#[cfg(feature = "wgpu")]
Self::Wgpu(device) => f.debug_tuple("Wgpu").field(device).finish(),
#[cfg(feature = "webgpu")]
Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(),
#[cfg(feature = "flex")]
Self::Flex(device) => f.debug_tuple("Flex").field(device).finish(),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -189,8 +192,8 @@ impl Default for DispatchDevice {
);
}
"metal" => {
#[cfg(wgpu_metal)]
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "metal")]
return Self::Metal(burn_wgpu::WgpuDevice::default());
panic!(
"BURN_DEVICE=metal requested, but the 'metal' feature is not enabled."
);
@@ -203,17 +206,24 @@ impl Default for DispatchDevice {
);
}
"vulkan" => {
#[cfg(wgpu_vulkan)]
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "vulkan")]
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
panic!(
"BURN_DEVICE=vulkan requested, but the 'vulkan' feature is not enabled."
);
}
"webgpu" | "wgpu" => {
#[cfg(wgpu_webgpu)]
"webgpu" => {
#[cfg(feature = "webgpu")]
return Self::WebGpu(burn_wgpu::WgpuDevice::default());
panic!(
"BURN_DEVICE=webgpu requested, but the 'webgpu' feature is not enabled."
);
}
"wgpu" => {
#[cfg(feature = "wgpu")]
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
panic!(
"BURN_DEVICE=wgpu requested, but the 'webgpu' or 'wgpu' feature is not enabled."
"BURN_DEVICE=wgpu requested, but the 'wgpu' feature is not enabled."
);
}
"cpu" => {
@@ -255,12 +265,21 @@ impl Default for DispatchDevice {
#[cfg(feature = "cuda")]
return Self::Cuda(CudaDevice::default());
#[cfg(feature = "metal")]
return Self::Metal(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "rocm")]
return Self::Rocm(RocmDevice::default());
#[cfg(feature = "vulkan")]
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "wgpu")]
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "webgpu")]
return Self::WebGpu(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "cpu")]
return Self::Cpu(CpuDevice);
@@ -296,10 +315,16 @@ impl PartialEq for DispatchDevice {
(Self::Cpu(a), Self::Cpu(b)) => a == b,
#[cfg(feature = "cuda")]
(Self::Cuda(a), Self::Cuda(b)) => a == b,
#[cfg(feature = "metal")]
(Self::Metal(a), Self::Metal(b)) => a == b,
#[cfg(feature = "rocm")]
(Self::Rocm(a), Self::Rocm(b)) => a == b,
#[cfg(feature = "vulkan")]
(Self::Vulkan(a), Self::Vulkan(b)) => a == b,
#[cfg(feature = "wgpu")]
(Self::Wgpu(a), Self::Wgpu(b)) => a == b,
#[cfg(feature = "webgpu")]
(Self::WebGpu(a), Self::WebGpu(b)) => a == b,
#[cfg(feature = "flex")]
(Self::Flex(a), Self::Flex(b)) => a == b,
#[cfg(any(feature = "ndarray", default_backend))]
@@ -350,10 +375,16 @@ impl DispatchDevice {
Self::Cpu(_) => DispatchDeviceId::Cpu,
#[cfg(feature = "cuda")]
Self::Cuda(_) => DispatchDeviceId::Cuda,
#[cfg(feature = "metal")]
Self::Metal(_) => DispatchDeviceId::Metal,
#[cfg(feature = "rocm")]
Self::Rocm(_) => DispatchDeviceId::Rocm,
#[cfg(feature = "vulkan")]
Self::Vulkan(_) => DispatchDeviceId::Vulkan,
#[cfg(feature = "wgpu")]
Self::Wgpu(_) => DispatchDeviceId::Wgpu,
#[cfg(feature = "webgpu")]
Self::WebGpu(_) => DispatchDeviceId::WebGpu,
#[cfg(feature = "flex")]
Self::Flex(_) => DispatchDeviceId::Flex,
#[cfg(any(feature = "ndarray", default_backend))]
@@ -393,13 +424,15 @@ impl DispatchDevice {
pub enum DispatchDeviceId {
Cpu = 0,
Cuda = 1,
// webgpu, vulkan and metal all point to WgpuDevice
Wgpu = 2,
Rocm = 3,
Flex = 4,
LibTorch = 5,
NdArray = 6,
Remote = 7,
Metal = 7,
Vulkan = 8,
WebGpu = 9,
Remote = 10,
}
impl From<DispatchDeviceId> for u16 {
@@ -427,8 +460,14 @@ impl TryFrom<u16> for DispatchDeviceId {
5 => Ok(Self::LibTorch),
#[cfg(any(feature = "ndarray", default_backend))]
6 => Ok(Self::NdArray),
#[cfg(feature = "metal")]
7 => Ok(Self::Metal),
#[cfg(feature = "vulkan")]
8 => Ok(Self::Vulkan),
#[cfg(feature = "webgpu")]
9 => Ok(Self::WebGpu),
#[cfg(feature = "remote")]
7 => Ok(Self::Remote),
10 => Ok(Self::Remote),
_ => Err(()),
}
}
@@ -441,10 +480,16 @@ impl DeviceOps for DispatchDevice {
Self::Cpu(device) => device.defaults(),
#[cfg(feature = "cuda")]
Self::Cuda(device) => device.defaults(),
#[cfg(feature = "metal")]
Self::Metal(device) => device.defaults(),
#[cfg(feature = "rocm")]
Self::Rocm(device) => device.defaults(),
#[cfg(feature = "vulkan")]
Self::Vulkan(device) => device.defaults(),
#[cfg(feature = "wgpu")]
Self::Wgpu(device) => device.defaults(),
#[cfg(feature = "webgpu")]
Self::WebGpu(device) => device.defaults(),
#[cfg(feature = "flex")]
Self::Flex(device) => device.defaults(),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -469,10 +514,16 @@ impl burn_backend::Device for DispatchDevice {
DispatchDeviceId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
#[cfg(feature = "cuda")]
DispatchDeviceId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
#[cfg(feature = "metal")]
DispatchDeviceId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
#[cfg(feature = "rocm")]
DispatchDeviceId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
#[cfg(feature = "vulkan")]
DispatchDeviceId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
#[cfg(feature = "wgpu")]
DispatchDeviceId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
#[cfg(feature = "webgpu")]
DispatchDeviceId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),
#[cfg(feature = "flex")]
DispatchDeviceId::Flex => Self::Flex(FlexDevice::from_id(device_id)),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -491,10 +542,16 @@ impl burn_backend::Device for DispatchDevice {
Self::Cpu(device) => device.to_id(),
#[cfg(feature = "cuda")]
Self::Cuda(device) => device.to_id(),
#[cfg(feature = "metal")]
Self::Metal(device) => device.to_id(),
#[cfg(feature = "rocm")]
Self::Rocm(device) => device.to_id(),
#[cfg(feature = "vulkan")]
Self::Vulkan(device) => device.to_id(),
#[cfg(feature = "wgpu")]
Self::Wgpu(device) => device.to_id(),
#[cfg(feature = "webgpu")]
Self::WebGpu(device) => device.to_id(),
#[cfg(feature = "flex")]
Self::Flex(device) => device.to_id(),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -525,13 +582,6 @@ impl From<CudaDevice> for DispatchDevice {
}
}
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
fn from(device: WgpuDevice) -> Self {
DispatchDevice::Wgpu(device)
}
}
#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
fn from(device: RocmDevice) -> Self {
@@ -539,14 +589,9 @@ impl From<RocmDevice> for DispatchDevice {
}
}
#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
fn from(device: WgpuDevice) -> Self {
DispatchDevice::Wgpu(device)
}
}
#[cfg(wgpu_webgpu)]
// A bare `WgpuDevice` maps to the auto-compiler [`DispatchDevice::Wgpu`] variant. To target a
// specific wgpu specialization (Metal, Vulkan, WebGpu) construct the variant explicitly.
#[cfg(feature = "wgpu")]
impl From<WgpuDevice> for DispatchDevice {
fn from(device: WgpuDevice) -> Self {
DispatchDevice::Wgpu(device)

View File

@@ -22,21 +22,9 @@
//! | `LibTorch` | `tch` | Libtorch backend via `tch` |
//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
//!
//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
//! All other backends can be combined freely.
//!
//! ## WGPU Backend Exclusivity
//!
//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to
//! the current automatic compile, which can only select one target at a time.
//!
//! Enable only **one** of these features in your `Cargo.toml`:
//! - `metal`
//! - `vulkan`
//! - `webgpu`
//!
//! If multiple WGPU features are enabled, the build script will emit a warning and enable `webgpu` only
//! to prevent unintended behavior.
//! **Note:** All backends, including the WGPU-based ones (`wgpu`, `metal`, `vulkan`, `webgpu`),
//! can be combined freely. Each enabled wgpu backend appears as its own
//! [`DispatchDevice`] variant.
#[macro_use]
mod macros;
@@ -80,11 +68,11 @@ pub mod backends {
pub use burn_rocm::Rocm;
#[cfg(feature = "wgpu")]
pub use burn_wgpu as wgpu;
#[cfg(wgpu_metal)]
#[cfg(feature = "metal")]
pub use burn_wgpu::Metal;
#[cfg(wgpu_vulkan)]
#[cfg(feature = "vulkan")]
pub use burn_wgpu::Vulkan;
#[cfg(all(wgpu_webgpu, feature = "webgpu"))]
#[cfg(feature = "webgpu")]
pub use burn_wgpu::WebGpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;

View File

@@ -1,3 +1,44 @@
/// Resolves a backend identifier to its concrete `Backend`-impl type.
///
/// Element type parameters have been removed project-wide, so every backend resolves the same
/// way: `inst!(Name)` expands to `$crate::backends::Name`. The macro is kept so the dispatch
/// macros stay agnostic to future signature changes; if a backend ever needs a different
/// expansion, add a dedicated arm for it *above* the generic one.
///
/// An identifier with no matching item in `$crate::backends` (an unknown name, or a backend
/// whose Cargo feature is disabled) is rejected at compile time as an unresolved path.
#[allow(unused_macros)]
macro_rules! inst {
    // Single passthrough arm: a new backend only needs its re-export in `backends`, not a
    // matching arm here.
    ($backend:ident) => {
        $crate::backends::$backend
    };
}
/// Supplies a list of all supported backends and their corresponding feature flags
/// to a callback macro. This centralizes the backend registry.
macro_rules! backend_list {
@@ -6,8 +47,11 @@ macro_rules! backend_list {
$($extra)*;
[Cpu, feature = "cpu"],
[Cuda, feature = "cuda"],
[Metal, feature = "metal"],
[Rocm, feature = "rocm"],
[Vulkan, feature = "vulkan"],
[Wgpu, feature = "wgpu"],
[WebGpu, feature = "webgpu"],
[Flex, feature = "flex"],
[NdArray, any(feature = "ndarray", default_backend)],
[LibTorch, feature = "tch"],
@@ -21,14 +65,17 @@ macro_rules! backend_matrix {
($callback:ident, $($extra:tt)*) => {
$callback! {
$($extra)*;
[Cpu, feature = "cpu"] => [[Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Cuda, feature = "cuda"] => [[Cpu, feature = "cpu"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Rocm, feature = "rocm"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Wgpu, feature = "wgpu"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Flex, feature = "flex"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[NdArray, any(feature = "ndarray", default_backend)] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[LibTorch, feature = "tch"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [Remote, feature = "remote"]];
[Remote, feature = "remote"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"]]
[Cpu, feature = "cpu"] => [[Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Cuda, feature = "cuda"] => [[Cpu, feature = "cpu"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Metal, feature = "metal"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Rocm, feature = "rocm"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Vulkan, feature = "vulkan"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Wgpu, feature = "wgpu"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[WebGpu, feature = "webgpu"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[Flex, feature = "flex"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[NdArray, any(feature = "ndarray", default_backend)] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
[LibTorch, feature = "tch"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [Remote, feature = "remote"]];
[Remote, feature = "remote"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"]]
}
};
}

View File

@@ -19,10 +19,22 @@ pub fn start_websocket(device: DispatchDevice, port: u16) {
DispatchDevice::Cpu(device) => burn_remote::server::start_websocket::<Cpu>(device, port),
#[cfg(feature = "cuda")]
DispatchDevice::Cuda(device) => burn_remote::server::start_websocket::<Cuda>(device, port),
#[cfg(feature = "metal")]
DispatchDevice::Metal(device) => {
burn_remote::server::start_websocket::<Metal>(device, port)
}
#[cfg(feature = "rocm")]
DispatchDevice::Rocm(device) => burn_remote::server::start_websocket::<Rocm>(device, port),
#[cfg(feature = "vulkan")]
DispatchDevice::Vulkan(device) => {
burn_remote::server::start_websocket::<Vulkan>(device, port)
}
#[cfg(feature = "wgpu")]
DispatchDevice::Wgpu(device) => burn_remote::server::start_websocket::<Wgpu>(device, port),
#[cfg(feature = "webgpu")]
DispatchDevice::WebGpu(device) => {
burn_remote::server::start_websocket::<WebGpu>(device, port)
}
#[cfg(feature = "flex")]
DispatchDevice::Flex(device) => burn_remote::server::start_websocket::<Flex>(device, port),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -57,14 +69,26 @@ pub async fn start_websocket_async(device: DispatchDevice, port: u16) {
DispatchDevice::Cuda(device) => {
burn_remote::server::start_websocket_async::<Cuda>(device, port).await
}
#[cfg(feature = "metal")]
DispatchDevice::Metal(device) => {
burn_remote::server::start_websocket_async::<Metal>(device, port).await
}
#[cfg(feature = "rocm")]
DispatchDevice::Rocm(device) => {
burn_remote::server::start_websocket_async::<Rocm>(device, port).await
}
#[cfg(feature = "vulkan")]
DispatchDevice::Vulkan(device) => {
burn_remote::server::start_websocket_async::<Vulkan>(device, port).await
}
#[cfg(feature = "wgpu")]
DispatchDevice::Wgpu(device) => {
burn_remote::server::start_websocket_async::<Wgpu>(device, port).await
}
#[cfg(feature = "webgpu")]
DispatchDevice::WebGpu(device) => {
burn_remote::server::start_websocket_async::<WebGpu>(device, port).await
}
#[cfg(feature = "flex")]
DispatchDevice::Flex(device) => {
burn_remote::server::start_websocket_async::<Flex>(device, port).await

View File

@@ -191,16 +191,26 @@ pub enum DispatchTensorKind {
#[cfg(feature = "cuda")]
Cuda(BackendTensor<Cuda>),
/// The [Metal backend](Metal) tensor.
#[cfg(feature = "metal")]
Metal(BackendTensor<Metal>),
/// The [ROCm backend](Rocm) tensor.
#[cfg(feature = "rocm")]
Rocm(BackendTensor<Rocm>),
#[cfg_attr(wgpu_webgpu, doc = r#"The [WebGPU backend](Wgpu) tensor"#)]
#[cfg_attr(wgpu_vulkan, doc = r#"The [Vulkan backend](Vulkan) tensor"#)]
#[cfg_attr(wgpu_metal, doc = r#"The [Metal backend](Metal) tensor"#)]
/// The [Vulkan backend](Vulkan) tensor.
#[cfg(feature = "vulkan")]
Vulkan(BackendTensor<Vulkan>),
/// The [Wgpu backend](Wgpu) tensor.
#[cfg(feature = "wgpu")]
Wgpu(BackendTensor<Wgpu>),
/// The [WebGPU backend](WebGpu) tensor.
#[cfg(feature = "webgpu")]
WebGpu(BackendTensor<WebGpu>),
/// The [Flex backend](Flex) tensor.
#[cfg(feature = "flex")]
Flex(BackendTensor<Flex>),
@@ -229,10 +239,16 @@ impl TensorMetadata for DispatchTensorKind {
Self::Cpu(tensor) => tensor.dtype(),
#[cfg(feature = "cuda")]
Self::Cuda(tensor) => tensor.dtype(),
#[cfg(feature = "metal")]
Self::Metal(tensor) => tensor.dtype(),
#[cfg(feature = "rocm")]
Self::Rocm(tensor) => tensor.dtype(),
#[cfg(feature = "vulkan")]
Self::Vulkan(tensor) => tensor.dtype(),
#[cfg(feature = "wgpu")]
Self::Wgpu(tensor) => tensor.dtype(),
#[cfg(feature = "webgpu")]
Self::WebGpu(tensor) => tensor.dtype(),
#[cfg(feature = "flex")]
Self::Flex(tensor) => tensor.dtype(),
#[cfg(any(feature = "ndarray", default_backend))]
@@ -252,10 +268,16 @@ impl TensorMetadata for DispatchTensorKind {
Self::Cpu(tensor) => tensor.shape(),
#[cfg(feature = "cuda")]
Self::Cuda(tensor) => tensor.shape(),
// Gated on the `metal` Cargo feature alone, exactly like the `Metal` variant itself, so the
// arm is present whenever the variant is and the match stays exhaustive.
#[cfg(feature = "metal")]
Self::Metal(tensor) => tensor.shape(),
#[cfg(feature = "rocm")]
Self::Rocm(tensor) => tensor.shape(),
#[cfg(feature = "vulkan")]
Self::Vulkan(tensor) => tensor.shape(),
#[cfg(feature = "wgpu")]
Self::Wgpu(tensor) => tensor.shape(),
#[cfg(feature = "webgpu")]
Self::WebGpu(tensor) => tensor.shape(),
#[cfg(feature = "flex")]
Self::Flex(tensor) => tensor.shape(),
#[cfg(any(feature = "ndarray", default_backend))]

View File

@@ -41,10 +41,10 @@ rocm = ["burn-dispatch/rocm"]
flex = ["burn-dispatch/flex"]
ndarray = ["burn-dispatch/ndarray"]
tch = ["burn-dispatch/tch"]
vulkan = ["burn-dispatch/vulkan"]
webgpu = ["burn-dispatch/webgpu"]
vulkan = ["wgpu", "burn-dispatch/vulkan"]
webgpu = ["wgpu", "burn-dispatch/webgpu"]
wgpu = ["burn-dispatch/wgpu"]
metal = ["burn-dispatch/metal"]
metal = ["wgpu", "burn-dispatch/metal"]
cpu = ["burn-dispatch/cpu"]
autodiff = ["burn-dispatch/autodiff"]
remote = ["std", "burn-dispatch/remote"]

View File

@@ -350,40 +350,34 @@ impl Device {
/// equivalent calls (they exist only to make the intended backend
/// explicit at the call site — the underlying adapter is still picked by
/// enabled Cargo features).
#[cfg(any(
feature = "wgpu",
feature = "vulkan",
feature = "metal",
feature = "webgpu"
))]
#[cfg(feature = "wgpu")]
pub fn wgpu(device_kind: DeviceKind) -> Self {
Self::new(wgpu_device(device_kind))
Self::new(DispatchDevice::Wgpu(wgpu_device(device_kind)))
}
/// Vulkan-backed WGPU device, selected via [`DeviceKind`].
///
/// Equivalent to [`Device::wgpu`]; provided so the caller's choice of
/// graphics API reads clearly at the call site. The actual adapter is
/// still determined by enabled Cargo features.
/// Pins the wgpu shader compiler to SPIR-V at compile time, avoiding
/// the runtime [`AutoCompiler`](burn_dispatch::backends::wgpu::AutoCompiler) dispatch.
#[cfg(feature = "vulkan")]
pub fn vulkan(device_kind: DeviceKind) -> Self {
Self::new(wgpu_device(device_kind))
Self::new(DispatchDevice::Vulkan(wgpu_device(device_kind)))
}
/// Metal-backed WGPU device, selected via [`DeviceKind`].
///
/// See [`Device::vulkan`] — same shape, different feature gate.
/// Pins the wgpu shader compiler to MSL at compile time.
#[cfg(feature = "metal")]
pub fn metal(device_kind: DeviceKind) -> Self {
Self::new(wgpu_device(device_kind))
Self::new(DispatchDevice::Metal(wgpu_device(device_kind)))
}
/// WebGPU-backed device, selected via [`DeviceKind`].
///
/// See [`Device::vulkan`] — same shape, different feature gate.
/// Pins the wgpu shader compiler to WGSL at compile time.
#[cfg(feature = "webgpu")]
pub fn webgpu(device_kind: DeviceKind) -> Self {
Self::new(wgpu_device(device_kind))
Self::new(DispatchDevice::WebGpu(wgpu_device(device_kind)))
}
/// Enables autodiff on this device.
@@ -585,31 +579,34 @@ impl Device {
#[allow(clippy::never_loop)] // at least one backend is expected to be enabled.
for device_type in filter.into() {
#[allow(unused)]
let type_id = match device_type {
let type_ids: &[DispatchDeviceId] = match device_type {
#[cfg(feature = "cpu")]
DeviceType::Cpu => DispatchDeviceId::Cpu,
DeviceType::Cpu => &[DispatchDeviceId::Cpu],
#[cfg(feature = "cuda")]
DeviceType::Cuda => DispatchDeviceId::Cuda,
DeviceType::Cuda => &[DispatchDeviceId::Cuda],
#[cfg(feature = "rocm")]
DeviceType::Rocm => DispatchDeviceId::Rocm,
#[cfg(any(
feature = "wgpu",
feature = "metal",
feature = "vulkan",
feature = "webgpu"
))]
DeviceType::Wgpu => DispatchDeviceId::Wgpu,
DeviceType::Rocm => &[DispatchDeviceId::Rocm],
#[cfg(feature = "wgpu")]
DeviceType::Wgpu => &[DispatchDeviceId::Wgpu],
#[cfg(feature = "metal")]
DeviceType::Metal => &[DispatchDeviceId::Metal],
#[cfg(feature = "vulkan")]
DeviceType::Vulkan => &[DispatchDeviceId::Vulkan],
#[cfg(feature = "webgpu")]
DeviceType::WebGpu => &[DispatchDeviceId::WebGpu],
#[cfg(feature = "flex")]
DeviceType::Flex => DispatchDeviceId::Flex,
DeviceType::Flex => &[DispatchDeviceId::Flex],
#[cfg(feature = "ndarray")]
DeviceType::NdArray => DispatchDeviceId::NdArray,
DeviceType::NdArray => &[DispatchDeviceId::NdArray],
#[cfg(feature = "tch")]
DeviceType::LibTorch => DispatchDeviceId::LibTorch,
DeviceType::LibTorch => &[DispatchDeviceId::LibTorch],
};
#[allow(unreachable_code)] // need to have one backend enabled, so it is reachable
for device in Dispatch::enumerate(type_id) {
devices.push(Device::new(device))
for type_id in type_ids {
for device in Dispatch::enumerate(*type_id) {
devices.push(Device::new(device))
}
}
}
@@ -621,12 +618,7 @@ impl Device {
///
/// Shared by [`Device::wgpu`], [`Device::vulkan`], [`Device::metal`], and
/// [`Device::webgpu`]; each wraps the returned device in its own [`DispatchDevice`] variant.
#[cfg(any(
feature = "wgpu",
feature = "vulkan",
feature = "metal",
feature = "webgpu"
))]
#[cfg(feature = "wgpu")]
fn wgpu_device(device_kind: DeviceKind) -> burn_dispatch::devices::WgpuDevice {
use burn_dispatch::devices::WgpuDevice;
match device_kind {
@@ -653,13 +645,14 @@ pub enum DeviceType {
Cuda,
#[cfg(feature = "rocm")]
Rocm,
#[cfg(any(
feature = "wgpu",
feature = "metal",
feature = "vulkan",
feature = "webgpu"
))]
#[cfg(feature = "wgpu")]
Wgpu,
#[cfg(feature = "metal")]
Metal,
#[cfg(feature = "vulkan")]
Vulkan,
#[cfg(feature = "webgpu")]
WebGpu,
#[cfg(feature = "flex")]
Flex,
#[cfg(feature = "ndarray")]

View File

@@ -14,6 +14,13 @@ pub use burn_cubecl::{CubeBackend, tensor::CubeTensor};
pub use cubecl::CubeDim;
pub use cubecl::flex32;
#[cfg(feature = "metal")]
use cubecl::wgpu::MslCompiler;
#[cfg(feature = "vulkan")]
use cubecl::wgpu::SpirvCompiler;
#[cfg(feature = "webgpu")]
use cubecl::wgpu::WgslCompiler;
pub use cubecl::wgpu::{
AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,
WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,
@@ -23,12 +30,12 @@ pub mod graphics {
pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
}
#[cfg(feature = "webgpu")]
pub use cubecl::wgpu::WgslCompiler;
#[cfg(feature = "vulkan")]
pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
// Shared backend type behind every wgpu alias in this file, generic over the compiler type `C`
// handed to `WgpuRuntime`. With `fusion` enabled the cube backend is wrapped in `Fusion`.
#[cfg(feature = "fusion")]
type WgpuInner<C> = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime<C>>>;
// Same alias without the `fusion` feature: the cube backend is used directly.
#[cfg(not(feature = "fusion"))]
type WgpuInner<C> = CubeBackend<cubecl::wgpu::WgpuRuntime<C>>;
/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
///
/// This backend can target multiple graphics APIs, including:
@@ -38,39 +45,11 @@ pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
/// - [Metal][crate::graphics::Metal] on Apple hardware.
/// - [WebGPU](crate::graphics::WebGpu) on supported browsers and `wasm` runtimes.
///
/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
/// you have to manually initialize the runtime. For example:
///
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
/// It's also possible to use an existing wgpu device, by using `init_device`.
///
/// # Notes
///
/// This version of the wgpu backend uses [burn_fusion] to compile and optimize streams of tensor
/// operations for improved performance.
///
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
pub type Wgpu = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime>>;
#[cfg(not(feature = "fusion"))]
/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
///
/// This backend can target multiple graphics APIs, including:
/// - [Vulkan] on Linux, Windows, and Android.
/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
/// - [DirectX 12](crate::Dx12) on Windows.
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
/// The selected graphics API is chosen automatically at runtime, and the appropriate shader
/// compiler (WGSL, SPIR-V or MSL) is dispatched via [`AutoCompiler`]. When the target API is
/// known ahead of time, prefer the dedicated `Vulkan`, `WebGpu` or `Metal` backend aliases
/// (enabled by their respective Cargo features), which lock the compiler at compile time and
/// avoid the runtime dispatch.
///
/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
/// you have to manually initialize the runtime. For example:
@@ -89,24 +68,34 @@ pub type Wgpu = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime>>;
///
/// # Notes
///
/// This version of the wgpu backend doesn't use [burn_fusion] to compile and optimize streams of tensor
/// operations.
///
/// You can enable the `fusion` feature flag to add that functionality, which might improve
/// performance.
pub type Wgpu = CubeBackend<cubecl::wgpu::WgpuRuntime>;
/// When the `fusion` feature flag is enabled (the default), this backend uses [burn_fusion] to
/// compile and optimize streams of tensor operations for improved performance. You can disable
/// the `fusion` feature flag to remove that functionality, which might be necessary on `wasm`
/// for now.
pub type Wgpu = WgpuInner<AutoCompiler>;
#[cfg(feature = "vulkan")]
/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V.
pub type Vulkan = Wgpu;
///
/// This is a specialization of [`Wgpu`] that pins the shader compiler to SPIR-V at compile time,
/// removing the runtime [`AutoCompiler`] dispatch. Enable the `vulkan` feature to use it.
/// Multiple wgpu backend aliases (`Vulkan`, `WebGpu`, `Metal`) can be enabled simultaneously
/// since each is a distinct type parameterized by its own compiler.
#[cfg(feature = "vulkan")]
pub type Vulkan = WgpuInner<SpirvCompiler>;
#[cfg(feature = "webgpu")]
/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
pub type WebGpu = Wgpu;
///
/// This is a specialization of [`Wgpu`] that pins the shader compiler to WGSL at compile time,
/// removing the runtime [`AutoCompiler`] dispatch. Enable the `webgpu` feature to use it.
#[cfg(feature = "webgpu")]
pub type WebGpu = WgpuInner<WgslCompiler>;
#[cfg(feature = "metal")]
/// Tensor backend that leverages the Metal graphics API to execute GPU compute shaders compiled to MSL.
pub type Metal = Wgpu;
///
/// This is a specialization of [`Wgpu`] that pins the shader compiler to MSL at compile time,
/// removing the runtime [`AutoCompiler`] dispatch. Enable the `metal` feature to use it.
#[cfg(feature = "metal")]
pub type Metal = WgpuInner<MslCompiler>;
#[cfg(test)]
mod tests {