fix: clean rewrite of cutlass_nvfp4_gemm.cu — no more file splicing
Removed dead code from old idx2crd approach. File is now clean: - Source-iterating SF remap kernel with layout_sf(m, k_elem) - Zero-init dest buffers before remap - Proper extern C wrapping
This commit is contained in:
@@ -93,12 +93,17 @@ using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Scale factor remap kernel using CuTe layout operations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
|
||||
//
|
||||
// Iterates over source (m, k_group) indices and uses layout_sf(m, k_elem) to
|
||||
// compute the CUTLASS destination offset. CuTe's layout operator() decomposes
|
||||
// flat coordinates into nested sub-coordinates automatically — no idx2crd,
|
||||
// no flatten, no rank inspection.
|
||||
//
|
||||
// K is in element-space for the layout: k_group * SFVecSize.
|
||||
// Iterating groups (not every element) is efficient since all 16 elements
|
||||
// within a group share one SF byte.
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename LayoutSF>
|
||||
@@ -109,17 +114,6 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
int MN, int K_sf, // Source dimensions (in SF groups)
|
||||
int SFVecSize // Elements per SF group (16 for NVFP4)
|
||||
) {
|
||||
// SOURCE-ITERATING: for each (m, k_group) in the row-major source,
|
||||
// compute the CUTLASS destination offset using layout_sf(m, k_elem).
|
||||
//
|
||||
// CuTe's layout operator() accepts "natural" coordinates whose rank matches
|
||||
// the top-level mode count (rank-2 here: M and K). It decomposes flat ints
|
||||
// into nested sub-coordinates automatically. No idx2crd, no flatten, no rank inspection.
|
||||
//
|
||||
// The K coordinate is in ELEMENT-SPACE (not group-space):
|
||||
// k_group=3 with SFVecSize=16 -> k_elem=48 in the layout.
|
||||
// All 16 elements within a group share the same SF byte, so we iterate
|
||||
// over groups (not elements) for efficiency — one write per SF value.
|
||||
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = MN * K_sf;
|
||||
if (src_idx >= total) return;
|
||||
@@ -128,14 +122,8 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
int k_group = src_idx % K_sf;
|
||||
int k_elem = k_group * SFVecSize;
|
||||
|
||||
// layout_sf(m, k_elem) -> CUTLASS linear index
|
||||
// CuTe handles the nested shape decomposition internally.
|
||||
dst[layout_sf(m, k_elem)] = src[src_idx];
|
||||
}
|
||||
if (m < MN && k_sf_out < K_sf) {
|
||||
dst[dst_idx] = src[m * K_sf + k_sf_out];
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// C API
|
||||
@@ -169,16 +157,12 @@ int cutlass_nvfp4_gemm_run(
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
// Allocate CUTLASS-layout buffers, zero-init, and remap scales
|
||||
// Zero-init is critical: CUTLASS pads to tile boundaries (128x64),
|
||||
// so dest is larger than M*K_sf. Unmapped slots must be zero to avoid
|
||||
// garbage values in the GEMM.
|
||||
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
|
||||
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
|
||||
cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
|
||||
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
// Grid size based on SOURCE elements (MN * K_sf) since kernel iterates source
|
||||
int sfa_src = M * K_sf;
|
||||
int sfb_src = N * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(
|
||||
|
||||
Reference in New Issue
Block a user