fix: 4 kernel compilation fixes for packed FP4

1. sizeof_bits_v→sizeof_bits<T>::value (our CUTLASS lacks C++17 _v form)
2. reinterpret_cast<uint8_t*> for TMA copy and UMMA desc calls
   (smem_a returns float_e2m1_t* but templates expect uint8_t*)
3. kNumChunks extended to 4 (packed FP4 halved SMEM, need more chunks)
4. No code changes to PatternVisitor — all fixes at call sites
This commit is contained in:
2026-05-11 23:17:51 +00:00
parent 49e5646b42
commit d6551617c0

View File

@@ -172,9 +172,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// mxf4nvf4 reads packed FP4 from SMEM (2 values per byte), NOT unpacked.
// _unpacksmem_t was for mxf8f6f4 which reads FP4 as FP8 (1 byte/element).
// For mxf4nvf4: sizeof_bits = 4, SMEM stride = BLOCK_K/2 bytes, UMMA_K = 64.
static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
static_assert(cutlass::sizeof_bits<a_dtype_t>::value == 4,
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
static_assert(cutlass::sizeof_bits_v<b_dtype_t> == 4,
static_assert(cutlass::sizeof_bits<b_dtype_t>::value == 4,
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
// MMA configs
@@ -741,7 +741,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// TMA copy tokens and SFA, then arrive at full barrier
if (cute::elect_one_sync()) {
tma::copy<BLOCK_K / 2, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(
tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2);
tensor_map_a_ptr, full_barriers[stage_idx], reinterpret_cast<uint8_t*>(smem_a[stage_idx]), k_idx, m_idx, 2);
tma::copy<SF_BLOCK_M, 1, 0>(
tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
if (is_leader_cta) {
@@ -785,7 +785,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
if (cute::elect_one_sync()) {
// NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes
tma::copy<BLOCK_K / 2, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(
tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2);
tensor_map_b_ptr, full_barriers[stage_idx], reinterpret_cast<uint8_t*>(smem_b[stage_idx]), k_idx, n_idx, 2);
tma::copy<BLOCK_N, 1, 0>(
tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
if (is_leader_cta) {
@@ -815,8 +815,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
// NVFP4: UMMA descriptors use packed byte dimensions (BLOCK_K/2, uint8_t)
// because sizeof(float_e2m1_t)=1 but real stride is BLOCK_K/2 bytes per K-row
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K / 2, kSwizzleAMode, false, uint8_t>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K / 2, kSwizzleBMode, false, uint8_t>(smem_b[0], 0, 0);
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K / 2, kSwizzleAMode, false, uint8_t>(reinterpret_cast<uint8_t*>(smem_a[0]), 0, 0);
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K / 2, kSwizzleBMode, false, uint8_t>(reinterpret_cast<uint8_t*>(smem_b[0]), 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
@@ -1307,10 +1307,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
constexpr uint32_t kNumChunkSlots = 3;
constexpr uint32_t kNumMaxRegistersForBuffer = 128;
// NOTES: either 1 or 2 chunks for simplicity
// NOTES: 1, 2, or 4 chunks depending on smem/register constraints
// NOTES: Restrict on both smem and register
constexpr uint32_t kNumChunks =
kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2;
kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE
and kHidden <= 32 * kNumMaxRegistersForBuffer
? 1
: kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / 2 <= SMEM_BEFORE_BARRIER_SIZE
? 2
: 4;
constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks;
constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4);
constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32;