diff --git a/examples/cuda_matmul/src/main.rs b/examples/cuda_matmul/src/main.rs index 5191d710..647f1f4b 100644 --- a/examples/cuda_matmul/src/main.rs +++ b/examples/cuda_matmul/src/main.rs @@ -13,9 +13,9 @@ fn main() { .with(fmt::layer()) .init(); - let m = (2_usize).pow(6); - let n = (2_usize).pow(6); - let k = (2_usize).pow(6); + let n = (2_usize).pow(4); + let m = (2_usize).pow(3); + let k = (2_usize).pow(2); info!(m); info!(n); @@ -41,42 +41,25 @@ fn main() { // Generate random input tensors based on dimensions let mut rng = rand::rng(); let a_data: Vec = (0..(m * k)).map(|_| rng.random_range(0.0..1.0)).collect(); - let b_data_row_major: Vec = (0..(k * n)).map(|_| rng.random_range(0.0..1.0)).collect(); + let b_data: Vec = (0..(k * n)).map(|_| rng.random_range(0.0..1.0)).collect(); // Create candle tensors on CPU for verification let device = Device::Cpu; // A is row-major (m x k) let a_candle = Tensor::from_vec(a_data.clone(), (m, k), &device).unwrap(); - debug!("a_candle: {:?}", a_candle); - trace!("a_candle matrix:\n{}", a_candle); + debug!("A input matrix [{} x {}]:\n{}", m, k, a_candle); - // B needs to be column-major for luminal, so create it as row-major (k x n) - // then transpose to get column-major layout - let b_candle_row_major = Tensor::from_vec(b_data_row_major.clone(), (k, n), &device).unwrap(); - debug!("b_candle: {:?}", b_candle_row_major); - trace!("b_candle:\n{}", b_candle_row_major); + // B is row-major (k x n) + let b_candle = Tensor::from_vec(b_data.clone(), (k, n), &device).unwrap(); + debug!("B input matrix [{} x {}]:\n{}", k, n, b_candle); - // For luminal, we need B in column-major format, which means we store it as transposed - // Convert row-major (k x n) to column-major by transposing to (n x k) then reading as flat - let b_candle_transposed = b_candle_row_major.t().unwrap().contiguous().unwrap(); - debug!("b_candle.t(): {:?}", b_candle_transposed); - trace!("b_candle.t():\n{}", b_candle_transposed); - - let b_data_col_major = b_candle_transposed - .flatten_all() - .unwrap() - .to_vec1::() - .unwrap(); - - // Set input tensors for luminal - // A is row-major as is + // Set input tensors for luminal - both row-major rt.set_data(a, a_data.clone()); - // B is in column-major format - rt.set_data(b, b_data_col_major.clone()); + rt.set_data(b, b_data.clone()); // Use search limit of 1 to get HostMatmul like the test - rt = cx.search(rt, 5); + rt = cx.search(rt, 1); fs::write("llir.dot", (&rt.llir_graph).to_dot().unwrap()).unwrap(); @@ -94,8 +77,14 @@ fn main() { debug!("Average value: {}", avg); // Compute matmul using candle for verification - // a_candle is (m x k) row-major, b_candle_row_major is (k x n) row-major - let candle_result = a_candle.matmul(&b_candle_row_major).unwrap(); + // a_candle is (m x k) row-major, b_candle is (k x n) row-major + let candle_result = a_candle.matmul(&b_candle).unwrap(); + debug!("C candle result [{} x {}]:\n{}", m, n, candle_result); + + // Create luminal result as tensor for debug display + let luminal_result_tensor = Tensor::from_vec(luminal_result_data.clone(), (m, n), &device).unwrap(); + debug!("C luminal result [{} x {}]:\n{}", m, n, luminal_result_tensor); + let candle_result_flat = candle_result .to_vec2::() .unwrap()