fix: dispatch TMA byte counts for FP4 (kHidden/2), rename fp8→fp4 layout refs

This commit is contained in:
2026-05-11 20:47:58 +00:00
parent b3d1aae038
commit a554de8b24

View File

@@ -109,10 +109,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Registered inputs
const auto input_token_buffer = layout::Buffer(
fp8_token_layout, 1, kNumMaxTokensPerRank,
fp4_token_layout, 1, kNumMaxTokensPerRank,
workspace.get_end_ptr());
const auto input_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, kNumMaxTokensPerRank,
nvfp4_sf_layout, 1, kNumMaxTokensPerRank,
input_token_buffer.get_end_ptr());
const auto input_topk_idx_buffer = layout::Buffer(
input_topk_idx_layout, 1, kNumMaxTokensPerRank,
@@ -140,10 +140,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// L1 inputs
const auto l1_token_buffer = layout::Buffer(
fp8_token_layout, 1, kNumMaxPoolTokens,
fp4_token_layout, 1, kNumMaxPoolTokens,
input_topk_weights_buffer.get_end_ptr());
const auto l1_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, kNumPaddedSFPoolTokens,
nvfp4_sf_layout, 1, kNumPaddedSFPoolTokens,
l1_token_buffer.get_end_ptr());
const auto l1_topk_weights_buffer = layout::Buffer(
l1_topk_weights_layout, 1, kNumMaxPoolTokens,
@@ -151,11 +151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// L2 inputs
const auto l2_token_buffer = layout::Buffer(
fp8_intermediate_token_layout, 1, kNumMaxPoolTokens,
fp4_intermediate_token_layout, 1, kNumMaxPoolTokens,
l1_topk_weights_buffer.get_end_ptr()
);
const auto l2_sf_buffer = layout::Buffer(
fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens,
nvfp4_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens,
l2_token_buffer.get_end_ptr()
);
@@ -240,7 +240,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Assign shared memory for dispatch warps
const auto smem_expert_count = reinterpret_cast<uint32_t*>(smem_buffer);
const auto smem_send_buffers = layout::Buffer(
fp8_token_layout, kNumDispatchWarps, 1,
fp4_token_layout, kNumDispatchWarps, 1,
math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE));
// GEMM shared memory: C/D, A, B
@@ -561,11 +561,12 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// TMA load token from remote rank into shared memory
if (cute::elect_one_sync()) {
// NVFP4: activations are E2M1 packed, kHidden/2 bytes per token
ptx::tma_load_1d(
pull_buffer.get_base_ptr(),
sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(),
current_rank_in_expert_idx),
pull_mbarrier, kHidden);
pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes
}
__syncwarp();
@@ -597,7 +598,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
*l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr<float>() = weight;
// Wait for TMA token load to complete
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden);
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes
ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
// Store token to local L1 buffer via TMA