mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Register cuda_lite only under "cuda_lite", not "cuda" or "gpu"
Avoids confusion with cuda_heavy. Auto-detection now returns "cuda_lite" for CUDA tensors. Test scripts updated to match. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -15,7 +15,7 @@ pub struct CudaLiteDynBackend {
|
||||
}
|
||||
|
||||
impl DynBackend for CudaLiteDynBackend {
|
||||
fn name(&self) -> &str { "cuda" }
|
||||
fn name(&self) -> &str { "cuda_lite" }
|
||||
fn device_type(&self) -> &str { "cuda" }
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, _dtype: DType) {
|
||||
@@ -58,9 +58,7 @@ fn cuda_lite_factory(graph: &mut Graph, args: BackendCompileArgs) -> Result<Box<
|
||||
)
|
||||
}
|
||||
|
||||
/// Register under `"cuda_lite"`, `"cuda"`, and `"gpu"`.
|
||||
/// Register under `"cuda_lite"`.
|
||||
pub fn register() {
|
||||
register_backend("cuda_lite", cuda_lite_factory);
|
||||
register_backend("cuda", cuda_lite_factory);
|
||||
register_backend("gpu", cuda_lite_factory);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda_lite uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda_lite uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -16,7 +16,7 @@ def _detect_backend(example_inputs):
|
||||
if env_backend:
|
||||
return env_backend
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
return "cuda" if device.type == "cuda" else "native"
|
||||
return "cuda_lite" if device.type == "cuda" else "native"
|
||||
|
||||
|
||||
def _collect_weight_pointers(weights, backend):
|
||||
|
||||
Reference in New Issue
Block a user