fix: 4 kernel compilation fixes for packed FP4
1. sizeof_bits_v→sizeof_bits<T>::value (our CUTLASS lacks C++17 _v form) 2. reinterpret_cast<uint8_t*> for TMA copy and UMMA desc calls (smem_a returns float_e2m1_t* but templates expect uint8_t*) 3. kNumChunks extended to 4 (packed FP4 halved SMEM, need more chunks) 4. No code changes to PatternVisitor — all fixes at call sites
This commit is contained in:
@@ -172,9 +172,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// mxf4nvf4 reads packed FP4 from SMEM (2 values per byte), NOT unpacked.
|
||||
// _unpacksmem_t was for mxf8f6f4 which reads FP4 as FP8 (1 byte/element).
|
||||
// For mxf4nvf4: sizeof_bits = 4, SMEM stride = BLOCK_K/2 bytes, UMMA_K = 64.
|
||||
static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
|
||||
static_assert(cutlass::sizeof_bits<a_dtype_t>::value == 4,
|
||||
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
|
||||
static_assert(cutlass::sizeof_bits_v<b_dtype_t> == 4,
|
||||
static_assert(cutlass::sizeof_bits<b_dtype_t>::value == 4,
|
||||
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
|
||||
|
||||
// MMA configs
|
||||
@@ -741,7 +741,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// TMA copy tokens and SFA, then arrive at full barrier
|
||||
if (cute::elect_one_sync()) {
|
||||
tma::copy<BLOCK_K / 2, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(
|
||||
tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2);
|
||||
tensor_map_a_ptr, full_barriers[stage_idx], reinterpret_cast<uint8_t*>(smem_a[stage_idx]), k_idx, m_idx, 2);
|
||||
tma::copy<SF_BLOCK_M, 1, 0>(
|
||||
tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
|
||||
if (is_leader_cta) {
|
||||
@@ -785,7 +785,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
if (cute::elect_one_sync()) {
|
||||
// NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes
|
||||
tma::copy<BLOCK_K / 2, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(
|
||||
tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2);
|
||||
tensor_map_b_ptr, full_barriers[stage_idx], reinterpret_cast<uint8_t*>(smem_b[stage_idx]), k_idx, n_idx, 2);
|
||||
tma::copy<BLOCK_N, 1, 0>(
|
||||
tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
|
||||
if (is_leader_cta) {
|
||||
@@ -815,8 +815,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
// NVFP4: UMMA descriptors use packed byte dimensions (BLOCK_K/2, uint8_t)
|
||||
// because sizeof(float_e2m1_t)=1 but real stride is BLOCK_K/2 bytes per K-row
|
||||
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K / 2, kSwizzleAMode, false, uint8_t>(smem_a[0], 0, 0);
|
||||
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K / 2, kSwizzleBMode, false, uint8_t>(smem_b[0], 0, 0);
|
||||
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K / 2, kSwizzleAMode, false, uint8_t>(reinterpret_cast<uint8_t*>(smem_a[0]), 0, 0);
|
||||
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K / 2, kSwizzleBMode, false, uint8_t>(reinterpret_cast<uint8_t*>(smem_b[0]), 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
@@ -1307,10 +1307,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
constexpr uint32_t kNumChunkSlots = 3;
|
||||
constexpr uint32_t kNumMaxRegistersForBuffer = 128;
|
||||
|
||||
// NOTES: either 1 or 2 chunks for simplicity
|
||||
// NOTES: 1, 2, or 4 chunks depending on smem/register constraints
|
||||
// NOTES: Restrict on both smem and register
|
||||
constexpr uint32_t kNumChunks =
|
||||
kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2;
|
||||
kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE
|
||||
and kHidden <= 32 * kNumMaxRegistersForBuffer
|
||||
? 1
|
||||
: kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / 2 <= SMEM_BEFORE_BARRIER_SIZE
|
||||
? 2
|
||||
: 4;
|
||||
constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks;
|
||||
constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4);
|
||||
constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32;
|
||||
|
||||
Reference in New Issue
Block a user