fix: clean rewrite of cutlass_nvfp4_gemm.cu — no more file splicing

Removed dead code from old idx2crd approach. File is now clean:
- Source-iterating SF remap kernel with layout_sf(m, k_elem)
- Zero-init dest buffers before remap
- Proper extern C wrapping
This commit is contained in:
2026-05-14 15:31:03 +00:00
parent 196ee37fdb
commit 593ae998f8

View File

@@ -93,12 +93,17 @@ using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap kernel using CuTe layout operations
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
//
// Iterates over source (m, k_group) indices and uses layout_sf(m, k_elem) to
// compute the CUTLASS destination offset. CuTe's layout operator() decomposes
// flat coordinates into nested sub-coordinates automatically — no idx2crd,
// no flatten, no rank inspection.
//
// K is in element-space for the layout: k_group * SFVecSize.
// Iterating groups (not every element) is efficient since all 16 elements
// within a group share one SF byte.
/////////////////////////////////////////////////////////////////////////////////////////////////
template<typename LayoutSF>
@@ -109,17 +114,6 @@ __global__ void remap_sf_to_cutlass_kernel(
int MN, int K_sf, // Source dimensions (in SF groups)
int SFVecSize // Elements per SF group (16 for NVFP4)
) {
// SOURCE-ITERATING: for each (m, k_group) in the row-major source,
// compute the CUTLASS destination offset using layout_sf(m, k_elem).
//
// CuTe's layout operator() accepts "natural" coordinates whose rank matches
// the top-level mode count (rank-2 here: M and K). It decomposes flat ints
// into nested sub-coordinates automatically. No idx2crd, no flatten, no rank inspection.
//
// The K coordinate is in ELEMENT-SPACE (not group-space):
// k_group=3 with SFVecSize=16 -> k_elem=48 in the layout.
// All 16 elements within a group share the same SF byte, so we iterate
// over groups (not elements) for efficiency — one write per SF value.
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = MN * K_sf;
if (src_idx >= total) return;
@@ -128,14 +122,8 @@ __global__ void remap_sf_to_cutlass_kernel(
int k_group = src_idx % K_sf;
int k_elem = k_group * SFVecSize;
// layout_sf(m, k_elem) -> CUTLASS linear index
// CuTe handles the nested shape decomposition internally.
dst[layout_sf(m, k_elem)] = src[src_idx];
}
if (m < MN && k_sf_out < K_sf) {
dst[dst_idx] = src[m * K_sf + k_sf_out];
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// C API
@@ -169,16 +157,12 @@ int cutlass_nvfp4_gemm_run(
int K_sf = K / InputSFVectorSize;
// Allocate CUTLASS-layout buffers, zero-init, and remap scales
// Zero-init is critical: CUTLASS pads to tile boundaries (128x64),
// so dest is larger than M*K_sf. Unmapped slots must be zero to avoid
// garbage values in the GEMM.
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
int block = 256;
// Grid size based on SOURCE elements (MN * K_sf) since kernel iterates source
int sfa_src = M * K_sf;
int sfb_src = N * K_sf;
remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(