diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 321d4e1..d48dada 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -109,10 +109,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Registered inputs const auto input_token_buffer = layout::Buffer( - fp8_token_layout, 1, kNumMaxTokensPerRank, + fp4_token_layout, 1, kNumMaxTokensPerRank, workspace.get_end_ptr()); const auto input_sf_buffer = layout::Buffer( - fp8_sf_layout, 1, kNumMaxTokensPerRank, + nvfp4_sf_layout, 1, kNumMaxTokensPerRank, input_token_buffer.get_end_ptr()); const auto input_topk_idx_buffer = layout::Buffer( input_topk_idx_layout, 1, kNumMaxTokensPerRank, @@ -140,10 +140,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // L1 inputs const auto l1_token_buffer = layout::Buffer( - fp8_token_layout, 1, kNumMaxPoolTokens, + fp4_token_layout, 1, kNumMaxPoolTokens, input_topk_weights_buffer.get_end_ptr()); const auto l1_sf_buffer = layout::Buffer( - fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + nvfp4_sf_layout, 1, kNumPaddedSFPoolTokens, l1_token_buffer.get_end_ptr()); const auto l1_topk_weights_buffer = layout::Buffer( l1_topk_weights_layout, 1, kNumMaxPoolTokens, @@ -151,11 +151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // L2 inputs const auto l2_token_buffer = layout::Buffer( - fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + fp4_intermediate_token_layout, 1, kNumMaxPoolTokens, l1_topk_weights_buffer.get_end_ptr() ); const auto l2_sf_buffer = layout::Buffer( - fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + nvfp4_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, l2_token_buffer.get_end_ptr() ); @@ -240,7 +240,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Assign shared memory for dispatch warps const auto smem_expert_count = reinterpret_cast(smem_buffer); const auto smem_send_buffers = layout::Buffer( - fp8_token_layout, kNumDispatchWarps, 1, + fp4_token_layout, kNumDispatchWarps, 1, math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); // GEMM shared memory: C/D, A, B @@ -561,11 +561,12 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // TMA load token from remote rank into shared memory if (cute::elect_one_sync()) { + // NVFP4: activations are E2M1 packed, kHidden/2 bytes per token ptx::tma_load_1d( pull_buffer.get_base_ptr(), sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx), - pull_mbarrier, kHidden); + pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes } __syncwarp(); @@ -597,7 +598,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; // Wait for TMA token load to complete - ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); // Store token to local L1 buffer via TMA