fix: rewrite SF remap kernel — source-iterating with layout_sf(m, k_elem)
Ripped out idx2crd + flatten + get<> approach entirely. New kernel iterates over source indices (m, k_group) and uses layout_sf(m, k_elem) to compute the CUTLASS destination offset. CuTe handles nested shape decomposition internally — no rank inspection needed. K coordinate is in element-space (k_group * SFVecSize) as the layout expects. Iterates over groups (not every element) since all 16 elements within a group share one SF byte — avoids 16x redundant writes. Grid size based on source count (MN * K_sf), not dest buffer size.
This commit is contained in:
@@ -98,59 +98,40 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Scale factor remap kernel using CuTe layout operations
|
||||
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename LayoutSF>
|
||||
__global__ void remap_sf_to_cutlass_kernel(
|
||||
const cutlass::float_ue4m3_t* __restrict__ src,
|
||||
cutlass::float_ue4m3_t* __restrict__ dst,
|
||||
LayoutSF layout_sf,
|
||||
int MN, int K_sf
|
||||
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
|
||||
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized)
|
||||
LayoutSF layout_sf, // CuTe layout for dst
|
||||
int MN, int K_sf, // Source dimensions (in SF groups)
|
||||
int SFVecSize // Elements per SF group (16 for NVFP4)
|
||||
) {
|
||||
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = cute::size(layout_sf);
|
||||
if (dst_idx >= total) return;
|
||||
// SOURCE-ITERATING: for each (m, k_group) in the row-major source,
|
||||
// compute the CUTLASS destination offset using layout_sf(m, k_elem).
|
||||
//
|
||||
// CuTe's layout operator() accepts "natural" coordinates whose rank matches
|
||||
// the top-level mode count (rank-2 here: M and K). It decomposes flat ints
|
||||
// into nested sub-coordinates automatically. No idx2crd, no flatten, no rank inspection.
|
||||
//
|
||||
// The K coordinate is in ELEMENT-SPACE (not group-space):
|
||||
// k_group=3 with SFVecSize=16 -> k_elem=48 in the layout.
|
||||
// All 16 elements within a group share the same SF byte, so we iterate
|
||||
// over groups (not elements) for efficiency — one write per SF value.
|
||||
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = MN * K_sf;
|
||||
if (src_idx >= total) return;
|
||||
|
||||
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
|
||||
auto flat = cute::flatten(coord);
|
||||
|
||||
constexpr int flat_rank = cute::rank_v<decltype(flat)>;
|
||||
|
||||
// Debug: print flat_rank and first few coordinates (only for idx 0)
|
||||
if (dst_idx == 0) {
|
||||
printf("[remap] flat_rank=%d, MN=%d, K_sf=%d, total=%d\n", flat_rank, MN, K_sf, total);
|
||||
printf("[remap] layout shape: ");
|
||||
cute::print(layout_sf.shape());
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
int m, k_sf_out;
|
||||
|
||||
if constexpr (flat_rank == 6) {
|
||||
// Full nested: (inner_M, sub_M, tile_M, inner_K, sub_K, tile_K)
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int tile_m = cute::get<2>(flat);
|
||||
// get<3> = inner_k (within one SF group — same byte)
|
||||
int sub_k = cute::get<4>(flat);
|
||||
int tile_k = cute::get<5>(flat);
|
||||
|
||||
m = tile_m * 128 + inner_m * 4 + sub_m;
|
||||
k_sf_out = tile_k * 4 + sub_k;
|
||||
} else if constexpr (flat_rank == 4) {
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int sub_k = cute::get<2>(flat);
|
||||
int tile_k = cute::get<3>(flat);
|
||||
|
||||
m = inner_m * 4 + sub_m;
|
||||
k_sf_out = tile_k * 4 + sub_k;
|
||||
} else {
|
||||
m = cute::get<0>(flat);
|
||||
k_sf_out = cute::get<1>(flat);
|
||||
}
|
||||
int m = src_idx / K_sf;
|
||||
int k_group = src_idx % K_sf;
|
||||
int k_elem = k_group * SFVecSize;
|
||||
|
||||
// layout_sf(m, k_elem) -> CUTLASS linear index
|
||||
// CuTe handles the nested shape decomposition internally.
|
||||
dst[layout_sf(m, k_elem)] = src[src_idx];
|
||||
}
|
||||
if (m < MN && k_sf_out < K_sf) {
|
||||
dst[dst_idx] = src[m * K_sf + k_sf_out];
|
||||
}
|
||||
@@ -197,11 +178,13 @@ int cutlass_nvfp4_gemm_run(
|
||||
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
// Grid size based on DEST (CUTLASS buffer) size since kernel iterates dest
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
|
||||
// Grid size based on SOURCE elements (MN * K_sf) since kernel iterates source
|
||||
int sfa_src = M * K_sf;
|
||||
int sfb_src = N * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_src + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
Reference in New Issue
Block a user