Test: pass scales directly to CUTLASS (no remap) to diagnose layout issue

This commit is contained in:
2026-05-14 10:23:02 +00:00
parent a272bc49b0
commit 84becfac93

View File

@@ -178,32 +178,33 @@ int cutlass_nvfp4_gemm_run(
LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
// For now, pass scale factors directly (no remap) to test whether the kernel works.
// The CUTLASS layout may expect interleaved data, but let's see what happens
// with simple row-major input. If TMA reads garbage, we expect NaN or wrong values
// rather than crashes (since we verified that small sizes work).
//
// TODO: implement proper interleaved layout remap
// Temporary: allocate CUTLASS-layout buffers and do a simple copy
int K_sf = K / InputSFVectorSize;
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF;
int sfa_size = cute::size(layout_SFA);
int sfb_size = cute::size(layout_SFB);
int K_sf = K / InputSFVectorSize;
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
// Remap scales from simple row-major to CUTLASS interleaved layout
int block = 256;
remap_sf_to_cutlass_layout_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
remap_sf_to_cutlass_layout_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
// Pass scale factors directly — layout_SFA/layout_SFB tell CUTLASS
// how to index into the data. Our data is (MN, K/16) row-major, but
// CUTLASS expects its own interleaved layout.
// For initial testing, pass directly and check for NaN/wrong values.
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, 1},
{
static_cast<const ArrayElementA*>(A_ptr), stride_A,
static_cast<const ArrayElementB*>(B_ptr), stride_B,
sfa_cutlass.get(), layout_SFA,
sfb_cutlass.get(), layout_SFB
static_cast<const ElementSF*>(SFA_ptr), layout_SFA,
static_cast<const ElementSF*>(SFB_ptr), layout_SFB
},
{
{ alpha, beta },