diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 38a887d..dc77374 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -172,9 +172,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // mxf4nvf4 reads packed FP4 from SMEM (2 values per byte), NOT unpacked. // _unpacksmem_t was for mxf8f6f4 which reads FP4 as FP8 (1 byte/element). // For mxf4nvf4: sizeof_bits = 4, SMEM stride = BLOCK_K/2 bytes, UMMA_K = 64. - static_assert(cutlass::sizeof_bits_v == 4, + static_assert(cutlass::sizeof_bits::value == 4, "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM"); - static_assert(cutlass::sizeof_bits_v == 4, + static_assert(cutlass::sizeof_bits::value == 4, "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM"); // MMA configs @@ -741,7 +741,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // TMA copy tokens and SFA, then arrive at full barrier if (cute::elect_one_sync()) { tma::copy( - tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tensor_map_a_ptr, full_barriers[stage_idx], reinterpret_cast(smem_a[stage_idx]), k_idx, m_idx, 2); tma::copy( tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); if (is_leader_cta) { @@ -785,7 +785,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, if (cute::elect_one_sync()) { // NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes tma::copy( - tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tensor_map_b_ptr, full_barriers[stage_idx], reinterpret_cast(smem_b[stage_idx]), k_idx, n_idx, 2); tma::copy( tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); if (is_leader_cta) { @@ -815,8 +815,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); // NVFP4: UMMA descriptors use packed byte dimensions (BLOCK_K/2, uint8_t) // because sizeof(float_e2m1_t)=1 but real stride is BLOCK_K/2 bytes per K-row - auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); - auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(reinterpret_cast(smem_a[0]), 0, 0); + auto b_desc = mma::sm100::make_umma_desc(reinterpret_cast(smem_b[0]), 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -1307,10 +1307,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, constexpr uint32_t kNumChunkSlots = 3; constexpr uint32_t kNumMaxRegistersForBuffer = 128; - // NOTES: either 1 or 2 chunks for simplicity + // NOTES: 1, 2, or 4 chunks depending on smem/register constraints // NOTES: Restrict on both smem and register constexpr uint32_t kNumChunks = - kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2; + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE + and kHidden <= 32 * kNumMaxRegistersForBuffer + ? 1 + : kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / 2 <= SMEM_BEFORE_BARRIER_SIZE + ? 2 + : 4; constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32;