Fix: use idx2crd instead of get_coord for CuTe layout coordinate lookup
This commit is contained in:
@@ -120,11 +120,15 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
// The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{})
|
||||
// So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor.
|
||||
|
||||
auto coord = layout_sf.get_coord(idx);
|
||||
// idx2crd converts a linear index to a logical coordinate in the layout's space
|
||||
auto coord = cute::idx2crd(idx, layout_sf.shape(), layout_sf.stride());
|
||||
|
||||
// Extract the (m, k) pair from the coordinate tuple
|
||||
int m = cute::get<0>(cute::flatten(coord)); // Row index
|
||||
int k = cute::get<1>(cute::flatten(coord)); // K-group index
|
||||
// The coordinate is a nested tuple. For a 2D layout it's (c0, c1)
|
||||
// where c0 = row/M index and c1 = col/K-group index.
|
||||
// Flatten the nested tuple to extract the two logical coordinates.
|
||||
auto flat = cute::flatten(coord);
|
||||
int m = cute::get<0>(flat);
|
||||
int k = cute::get<1>(flat);
|
||||
|
||||
if (m < MN && k < K_sf) {
|
||||
dst[idx] = src[m * K_sf + k];
|
||||
|
||||
Reference in New Issue
Block a user