From a554de8b24fb160ce472cb407a90e2e4b8dbf525 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 20:47:58 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20dispatch=20TMA=20byte=20counts=20for=20F?= =?UTF-8?q?P4=20(kHidden/2),=20rename=20fp8=E2=86=92fp4=20layout=20refs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 321d4e1..d48dada 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -109,10 +109,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Registered inputs const auto input_token_buffer = layout::Buffer( - fp8_token_layout, 1, kNumMaxTokensPerRank, + fp4_token_layout, 1, kNumMaxTokensPerRank, workspace.get_end_ptr()); const auto input_sf_buffer = layout::Buffer( - fp8_sf_layout, 1, kNumMaxTokensPerRank, + nvfp4_sf_layout, 1, kNumMaxTokensPerRank, input_token_buffer.get_end_ptr()); const auto input_topk_idx_buffer = layout::Buffer( input_topk_idx_layout, 1, kNumMaxTokensPerRank, @@ -140,10 +140,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // L1 inputs const auto l1_token_buffer = layout::Buffer( - fp8_token_layout, 1, kNumMaxPoolTokens, + fp4_token_layout, 1, kNumMaxPoolTokens, input_topk_weights_buffer.get_end_ptr()); const auto l1_sf_buffer = layout::Buffer( - fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + nvfp4_sf_layout, 1, kNumPaddedSFPoolTokens, l1_token_buffer.get_end_ptr()); const auto l1_topk_weights_buffer = layout::Buffer( l1_topk_weights_layout, 1, kNumMaxPoolTokens, @@ -151,11 +151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // L2 inputs const auto l2_token_buffer = layout::Buffer( - fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + fp4_intermediate_token_layout, 1, kNumMaxPoolTokens, l1_topk_weights_buffer.get_end_ptr() ); const auto l2_sf_buffer = layout::Buffer( - fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + nvfp4_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, l2_token_buffer.get_end_ptr() ); @@ -240,7 +240,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Assign shared memory for dispatch warps const auto smem_expert_count = reinterpret_cast(smem_buffer); const auto smem_send_buffers = layout::Buffer( - fp8_token_layout, kNumDispatchWarps, 1, + fp4_token_layout, kNumDispatchWarps, 1, math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); // GEMM shared memory: C/D, A, B @@ -561,11 +561,12 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // TMA load token from remote rank into shared memory if (cute::elect_one_sync()) { + // NVFP4: activations are E2M1 packed, kHidden/2 bytes per token ptx::tma_load_1d( pull_buffer.get_base_ptr(), sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx), - pull_mbarrier, kHidden); + pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes } __syncwarp(); @@ -597,7 +598,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; // Wait for TMA token load to complete - ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); // Store token to local L1 buffer via TMA