fix: rewrite SF remap as forward mapping (source→dst)
- Iterate over source indices (MN * K_sf) instead of dst indices - Use layout_sf forward mapping: layout_sf(make_coord(mn, k_sf*16)) - No more idx2crd reverse extraction or stride-0 ambiguity - Cleaner, less error-prone, blog-compatible
This commit is contained in:
@@ -124,39 +124,40 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
int MN, int K_sf, // Source dimensions (in SF groups)
|
||||
bool col_major_src = false // true if source is (K_sf, MN) row-major
|
||||
) {
|
||||
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = cute::cosize(layout_sf);
|
||||
if (dst_idx >= total) return;
|
||||
// Forward-mapping approach: iterate over source indices (mn, k_sf),
|
||||
// compute the CUTLASS dst index via layout_sf forward mapping.
|
||||
// k_sf * InputSFVectorSize converts from SF-group index to logical K element
|
||||
// coordinate, which is what the layout expects.
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = MN * K_sf;
|
||||
if (tid >= total) return;
|
||||
|
||||
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
|
||||
auto flat = cute::flatten(coord);
|
||||
int mn, k_sf_val, src_idx;
|
||||
|
||||
constexpr int R = cute::rank_v<decltype(flat)>;
|
||||
|
||||
int mn = 0;
|
||||
int k_sf = 0;
|
||||
|
||||
if constexpr (R >= 6) {
|
||||
// K-major SF atom: Shape<Shape<_32,_4>, Shape<SFVecSize,_4>>
|
||||
// Stride<Stride<_16,_4>, Stride<_0,_1>>
|
||||
// SFA tiled as make_shape(M,K,...), SFB as make_shape(N,K,...)
|
||||
// First flattened group is M/N: ((32,4), tile_mn)
|
||||
mn =
|
||||
int(cute::get<0>(flat)) +
|
||||
32 * int(cute::get<1>(flat)) +
|
||||
128 * int(cute::get<2>(flat));
|
||||
|
||||
// Second flattened group is K: (SFVecSize, 4, tile_k)
|
||||
// get<3> is the stride-0 k-within-SF-vector coordinate — ignore it.
|
||||
k_sf =
|
||||
int(cute::get<4>(flat)) +
|
||||
4 * int(cute::get<5>(flat));
|
||||
if (col_major_src) {
|
||||
// source is row-major (K_sf, MN), e.g. SFB stored as (K_sf, N)
|
||||
k_sf_val = tid / MN;
|
||||
mn = tid % MN;
|
||||
src_idx = tid;
|
||||
} else {
|
||||
// source is row-major (MN, K_sf), e.g. SFA stored as (M, K_sf)
|
||||
mn = tid / K_sf;
|
||||
k_sf_val = tid % K_sf;
|
||||
src_idx = tid;
|
||||
}
|
||||
|
||||
if (mn < MN && k_sf < K_sf) {
|
||||
dst[dst_idx] = col_major_src
|
||||
? src[k_sf * MN + mn] // SFB source: (K_sf, N)
|
||||
: src[mn * K_sf + k_sf]; // SFA source: (M, K_sf)
|
||||
// Use layout forward mapping: source (mn, k_sf*16) -> dst_idx
|
||||
constexpr int LayoutRank = cute::rank_v<decltype(layout_sf.shape())>;
|
||||
int dst_idx = 0;
|
||||
|
||||
if constexpr (LayoutRank == 2) {
|
||||
dst_idx = layout_sf(cute::make_coord(mn, k_sf_val * InputSFVectorSize));
|
||||
} else if constexpr (LayoutRank == 3) {
|
||||
dst_idx = layout_sf(cute::make_coord(mn, k_sf_val * InputSFVectorSize, 0));
|
||||
}
|
||||
|
||||
if (dst_idx >= 0 && dst_idx < cute::cosize(layout_sf)) {
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,9 +198,11 @@ int cutlass_nvfp4_gemm_run(
|
||||
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
int sfa_src_total = M * K_sf;
|
||||
int sfb_src_total = N * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, true);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
@@ -260,7 +263,8 @@ extern "C" int cutlass_nvfp4_prepack_sfb_run(
|
||||
cudaMemsetAsync(static_cast<ElementSF*>(SFB_cutlass_ptr), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
int sfb_src_total = N * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr),
|
||||
static_cast<ElementSF*>(SFB_cutlass_ptr),
|
||||
layout_SFB,
|
||||
@@ -303,7 +307,8 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb(
|
||||
cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
int sfa_src_total = M * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
|
||||
Reference in New Issue
Block a user