mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
perf(flex): bulk-copy inner-contiguous run in to_contiguous (#5019)
This commit is contained in:
@@ -377,6 +377,11 @@ impl FlexTensor {
|
||||
copy_2d_tiled(
|
||||
&mut dst, src, offset, shape[0], shape[1], strides[0], strides[1],
|
||||
);
|
||||
} else if all_positive && strides.last().copied() == Some(1) {
|
||||
// Since the inner stride is 1, we can bulk-copy
|
||||
// the contiguous inner run for each outer index.
|
||||
unsafe { dst.set_len(n) };
|
||||
copy_inner_contiguous_run(&mut dst, src, offset, shape, strides);
|
||||
} else {
|
||||
// General fallback: covers negative strides (flipped
|
||||
// tensors) and ND layouts that can't collapse to ≤2D.
|
||||
@@ -546,6 +551,90 @@ fn collapse_for_copy(shape: &[usize], strides: &[isize]) -> CollapsedLayout {
|
||||
out
|
||||
}
|
||||
|
||||
/// Copy a strided source into a contiguous destination by treating the
|
||||
/// layout as a set of strided outer dims wrapping a contiguous inner run.
|
||||
///
|
||||
/// Precondition: strides are all positive and the innermost stride is 1
|
||||
#[inline]
|
||||
fn copy_inner_contiguous_run<E: Copy>(
|
||||
dst: &mut [E],
|
||||
src: &[E],
|
||||
offset: isize,
|
||||
shape: &[usize],
|
||||
strides: &[isize],
|
||||
) {
|
||||
debug_assert!(!shape.is_empty());
|
||||
debug_assert_eq!(*strides.last().unwrap(), 1, "innermost stride must be 1");
|
||||
|
||||
// We want the maximal trailing run that is contiguous in src.
|
||||
// Walk inward-out, a dim dim_idx extends the contiguous run if AND ONLY IF ("iff")
|
||||
// strides[dim_idx] == product of inner shapes seen so far.
|
||||
let mut expected_stride: isize = 1;
|
||||
let mut split = shape.len(); // first OUTER index (dims [split..] are inner)
|
||||
while split > 0 {
|
||||
let dim_idx = split - 1;
|
||||
if strides[dim_idx] == expected_stride {
|
||||
expected_stride = expected_stride
|
||||
.checked_mul(shape[dim_idx] as isize)
|
||||
.expect("contiguous run length overflows isize");
|
||||
split -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let outer_shape = &shape[..split];
|
||||
let outer_strides = &strides[..split];
|
||||
let inner_run_length: usize = shape[split..].iter().product();
|
||||
|
||||
// Whole tensor is one contiguous run! We can bulk-copy the whole thing.
|
||||
if outer_shape.is_empty() {
|
||||
let src_index = offset as usize;
|
||||
dst[..inner_run_length].copy_from_slice(&src[src_index..src_index + inner_run_length]);
|
||||
return;
|
||||
}
|
||||
|
||||
if inner_run_length == 0 || outer_shape.contains(&0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Walk the outer index space in row-major order via an odometer:
|
||||
// each iteration ticks the rightmost counter, and so on overflow we
|
||||
// reset it and carry +1 into the dim to its left.
|
||||
// Row-major is what we want because dst is contiguous (consecutive outer-index
|
||||
// tuples land in consecutive inner_run_length slots, so we just bump
|
||||
// dst_position each step instead of recomputing it)
|
||||
// We use this instead of N nested for loops because the rank isn't
|
||||
// known at compile time.
|
||||
let mut counter = [0usize; COLLAPSE_MAX_RANK];
|
||||
let outer_dim_count = outer_shape.len();
|
||||
let mut dst_position = 0usize;
|
||||
|
||||
loop {
|
||||
let mut src_start = offset;
|
||||
for dim_idx in 0..outer_dim_count {
|
||||
src_start += counter[dim_idx] as isize * outer_strides[dim_idx];
|
||||
}
|
||||
let src_index = src_start as usize;
|
||||
dst[dst_position..dst_position + inner_run_length]
|
||||
.copy_from_slice(&src[src_index..src_index + inner_run_length]);
|
||||
dst_position += inner_run_length;
|
||||
|
||||
let mut dim_idx = outer_dim_count;
|
||||
loop {
|
||||
if dim_idx == 0 {
|
||||
return;
|
||||
}
|
||||
dim_idx -= 1;
|
||||
counter[dim_idx] += 1;
|
||||
if counter[dim_idx] < outer_shape[dim_idx] {
|
||||
break;
|
||||
}
|
||||
counter[dim_idx] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tiled 2D copy from a strided source into a contiguous destination.
|
||||
/// The loop nesting is chosen so the innermost read walks whichever
|
||||
/// source stride is smaller, which keeps the hot loop in cache even
|
||||
@@ -763,6 +852,43 @@ mod tests {
|
||||
assert_eq!(values, expected.as_slice());
|
||||
}
|
||||
|
||||
/// Regression: the ShuffleNet channel shuffle (reshape -> permute([0,2,1,3,4]) -> reshape)
|
||||
/// collapses to a 3D layout with stride-1 innermost.
|
||||
/// Before this fix that went into the StridedIter scalar walk in the final else in copy_contiguous
|
||||
/// Now, we use a new inner-contiguous branch that bulk-copies the trailing run
|
||||
#[test]
|
||||
fn test_to_contiguous_3d_inner_stride1_matches_naive() {
|
||||
let dims = [1, 2, 4, 3, 3]; // [N, G, C, H, W]
|
||||
let n: usize = dims.iter().product();
|
||||
let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
let t = FlexTensor::from_data(TensorData::new(data.clone(), dims.to_vec()));
|
||||
let permuted = t.permute(&[0, 2, 1, 3, 4]); // swap G and C from the original dims in the permute
|
||||
assert!(!permuted.is_contiguous());
|
||||
|
||||
let contiguous_data = permuted.to_contiguous();
|
||||
assert!(contiguous_data.is_contiguous());
|
||||
assert_eq!(contiguous_data.shape().to_vec(), vec![1, 4, 2, 3, 3]);
|
||||
|
||||
// Now we manually do a strided walk to be able to check that against a naive solution.
|
||||
// output[0, c, g, h, w] = input[0, g, c, h, w].
|
||||
let mut expected = Vec::with_capacity(n);
|
||||
for c in 0..4 {
|
||||
for g in 0..2 {
|
||||
for h in 0..3 {
|
||||
for w in 0..3 {
|
||||
let idx = g * 36 + c * 9 + h * 3 + w;
|
||||
expected.push(data[idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the naive solution matches (the manual strided walk).
|
||||
let result_data = contiguous_data.into_data();
|
||||
let values = result_data.as_slice::<f32>().unwrap();
|
||||
assert_eq!(values, expected.as_slice());
|
||||
}
|
||||
|
||||
/// Exercise the `row_stride > col_stride` branch of the 2D tiled
|
||||
/// copy (the ConvNeXt case hits the other branch).
|
||||
#[test]
|
||||
@@ -832,7 +958,7 @@ mod tests {
|
||||
// Both point to same data
|
||||
assert!(core::ptr::eq(
|
||||
tensor.bytes().as_ptr(),
|
||||
cloned.bytes().as_ptr()
|
||||
cloned.bytes().as_ptr(),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user