fix: dispatch TMA byte counts for FP4 (kHidden/2), rename fp8→fp4 layout refs
This commit is contained in:
@@ -109,10 +109,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Registered inputs
|
||||
const auto input_token_buffer = layout::Buffer(
|
||||
fp8_token_layout, 1, kNumMaxTokensPerRank,
|
||||
fp4_token_layout, 1, kNumMaxTokensPerRank,
|
||||
workspace.get_end_ptr());
|
||||
const auto input_sf_buffer = layout::Buffer(
|
||||
fp8_sf_layout, 1, kNumMaxTokensPerRank,
|
||||
nvfp4_sf_layout, 1, kNumMaxTokensPerRank,
|
||||
input_token_buffer.get_end_ptr());
|
||||
const auto input_topk_idx_buffer = layout::Buffer(
|
||||
input_topk_idx_layout, 1, kNumMaxTokensPerRank,
|
||||
@@ -140,10 +140,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// L1 inputs
|
||||
const auto l1_token_buffer = layout::Buffer(
|
||||
fp8_token_layout, 1, kNumMaxPoolTokens,
|
||||
fp4_token_layout, 1, kNumMaxPoolTokens,
|
||||
input_topk_weights_buffer.get_end_ptr());
|
||||
const auto l1_sf_buffer = layout::Buffer(
|
||||
fp8_sf_layout, 1, kNumPaddedSFPoolTokens,
|
||||
nvfp4_sf_layout, 1, kNumPaddedSFPoolTokens,
|
||||
l1_token_buffer.get_end_ptr());
|
||||
const auto l1_topk_weights_buffer = layout::Buffer(
|
||||
l1_topk_weights_layout, 1, kNumMaxPoolTokens,
|
||||
@@ -151,11 +151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// L2 inputs
|
||||
const auto l2_token_buffer = layout::Buffer(
|
||||
fp8_intermediate_token_layout, 1, kNumMaxPoolTokens,
|
||||
fp4_intermediate_token_layout, 1, kNumMaxPoolTokens,
|
||||
l1_topk_weights_buffer.get_end_ptr()
|
||||
);
|
||||
const auto l2_sf_buffer = layout::Buffer(
|
||||
fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens,
|
||||
nvfp4_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens,
|
||||
l2_token_buffer.get_end_ptr()
|
||||
);
|
||||
|
||||
@@ -240,7 +240,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// Assign shared memory for dispatch warps
|
||||
const auto smem_expert_count = reinterpret_cast<uint32_t*>(smem_buffer);
|
||||
const auto smem_send_buffers = layout::Buffer(
|
||||
fp8_token_layout, kNumDispatchWarps, 1,
|
||||
fp4_token_layout, kNumDispatchWarps, 1,
|
||||
math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE));
|
||||
|
||||
// GEMM shared memory: C/D, A, B
|
||||
@@ -561,11 +561,12 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// TMA load token from remote rank into shared memory
|
||||
if (cute::elect_one_sync()) {
|
||||
// NVFP4: activations are E2M1 packed, kHidden/2 bytes per token
|
||||
ptx::tma_load_1d(
|
||||
pull_buffer.get_base_ptr(),
|
||||
sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(),
|
||||
current_rank_in_expert_idx),
|
||||
pull_mbarrier, kHidden);
|
||||
pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -597,7 +598,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
*l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr<float>() = weight;
|
||||
|
||||
// Wait for TMA token load to complete
|
||||
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden);
|
||||
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes
|
||||
ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
|
||||
|
||||
// Store token to local L1 buffer via TMA
|
||||
|
||||
Reference in New Issue
Block a user