Test: pass scales directly to CUTLASS (no remap) to diagnose layout issue
This commit is contained in:
@@ -178,32 +178,33 @@ int cutlass_nvfp4_gemm_run(
|
||||
LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
|
||||
// For now, pass scale factors directly (no remap) to test if the kernel works
|
||||
// The CUTLASS layout may expect interleaved data, but let's see what happens
|
||||
// with simple row-major. If TMA reads garbage, we'll get NaN or wrong values
|
||||
// rather than crashes (since we verified small sizes work).
|
||||
//
|
||||
// TODO: implement proper interleaved layout remap
|
||||
|
||||
// Temporary: allocate CUTLASS-layout buffers and do a simple copy
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF;
|
||||
|
||||
int sfa_size = cute::size(layout_SFA);
|
||||
int sfb_size = cute::size(layout_SFB);
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
|
||||
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
|
||||
|
||||
// Remap scales from simple row-major to CUTLASS interleaved layout
|
||||
int block = 256;
|
||||
remap_sf_to_cutlass_layout_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
|
||||
remap_sf_to_cutlass_layout_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
|
||||
|
||||
// Pass scale factors directly — the layout_SFA/layout_SFB tell CUTLASS
|
||||
// how to index into the data. Our data is (MN, K//16) row-major but
|
||||
// CUTLASS expects its own interleaved layout.
|
||||
// For initial testing, pass directly and check for NaN/wrong values.
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K, 1},
|
||||
{
|
||||
static_cast<const ArrayElementA*>(A_ptr), stride_A,
|
||||
static_cast<const ArrayElementB*>(B_ptr), stride_B,
|
||||
sfa_cutlass.get(), layout_SFA,
|
||||
sfb_cutlass.get(), layout_SFB
|
||||
static_cast<const ElementSF*>(SFA_ptr), layout_SFA,
|
||||
static_cast<const ElementSF*>(SFB_ptr), layout_SFB
|
||||
},
|
||||
{
|
||||
{ alpha, beta },
|
||||
|
||||
Reference in New Issue
Block a user