NVFP4: fix SF pipeline — 2 K-cols per BLOCK_K for group=16
- TMA: issue two tma::copy calls per K-block (K_box=1, 2 SF K-columns) - UTCCP: double loop for 2 K-columns, correct SMEM offsets - TMEM: double SFA/SFB column counts (SF_BLOCK_M/32 * 2) - Heuristic: fix smem_size (2× SF, packed FP4 A/B, packed send buffers, no amax) - Staging kernel: fix double-count bug in packed_k_mask
This commit is contained in:
@@ -138,32 +138,37 @@ static std::pair<int, int> get_pipeline_config_for_mega_moe(
|
||||
// Dispatch region
|
||||
const int smem_expert_count_size = align(
|
||||
num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
|
||||
// NVFP4: dispatch send buffers use packed E2M1 tokens (hidden/2 bytes per token)
|
||||
const int smem_send_buffers_size = align(
|
||||
static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
|
||||
static_cast<int>(layout::Buffer(layout::Data(hidden / 2), num_dispatch_warps, 1).get_num_bytes()),
|
||||
kSmemAlignment);
|
||||
const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size;
|
||||
|
||||
// C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage)
|
||||
// C/D output region: max of L1 packed E2M1 (2 TMA stages, BLOCK_N/4 bytes per row) and L2 BF16 (1 stage)
|
||||
// NVFP4 L1 output: packed E2M1 (2 per byte), L1_OUT_BLOCK_N = block_n/2, bytes = L1_OUT_BLOCK_N/2 = block_n/4
|
||||
const auto num_epilogue_warpgroups = num_epilogue_warps / 4;
|
||||
const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages;
|
||||
const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 4) * kNumTMAStoreStages;
|
||||
const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
|
||||
const int smem_cd = std::max(smem_cd_l1, smem_cd_l2);
|
||||
|
||||
// Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp)
|
||||
const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8;
|
||||
|
||||
// Amax reduction
|
||||
const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast<int>(sizeof(float));
|
||||
// NVFP4: no SMEM amax reduction needed (each warp computes its own amax)
|
||||
const int smem_amax_reduction = 0;
|
||||
|
||||
// Tensor memory pointer
|
||||
const int smem_tmem_ptr = 4;
|
||||
|
||||
// SF is aligned to UTCCP 128-element granularity
|
||||
const int smem_sfa_per_stage = sf_block_m * 4;
|
||||
const int smem_sfb_per_stage = sf_block_n * 4;
|
||||
// NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2)
|
||||
// Each K-column: sf_block_m * 4 bytes (uint32), total = 2× the MXFP4 SF size
|
||||
const int smem_sfa_per_stage = sf_block_m * 4 * 2; // 2 K-cols for NVFP4
|
||||
const int smem_sfb_per_stage = sf_block_n * 4 * 2; // 2 K-cols for NVFP4
|
||||
|
||||
// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
|
||||
const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
|
||||
// NVFP4: packed E2M1 (2 per byte), so A/B tiles use BLOCK_K/2 bytes per row
|
||||
const int smem_per_stage = load_block_m * (block_k / 2) + block_n * (block_k / 2) + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
|
||||
|
||||
// Fixed total
|
||||
const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
|
||||
|
||||
@@ -237,10 +237,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Tensor memory size
|
||||
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
|
||||
// NVFP4 scale_vec::4X: UTCCP writes 4 TMEM cols per 128-elem group (same as 1X).
|
||||
// The 4X flag only changes MMA scale interpretation, not UTCCP layout.
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
||||
// NVFP4 group=16: 2 SF K-columns per BLOCK_K, so 2× the TMEM cols vs MXFP4
|
||||
// Each 128-elem UTCCP group → 4 TMEM cols, and we have 2 groups per BLOCK_K
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 2; // 2 K-cols
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 2; // 2 K-cols
|
||||
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
||||
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
||||
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
||||
@@ -734,7 +734,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
uint32_t m_idx = pool_block_idx * BLOCK_M;
|
||||
uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset
|
||||
uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M;
|
||||
uint32_t sfa_k_idx = k_block_idx;
|
||||
|
||||
// Add 2 CTA offsets for non-leader CTA
|
||||
if (not is_leader_cta)
|
||||
@@ -744,8 +743,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
if (cute::elect_one_sync()) {
|
||||
tma::copy<BLOCK_K / 2, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(
|
||||
tensor_map_a_ptr, full_barriers[stage_idx], reinterpret_cast<uint8_t*>(smem_a[stage_idx]), k_idx, m_idx, 2);
|
||||
tma::copy<SF_BLOCK_M, 1, 0>(
|
||||
tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
|
||||
// NVFP4 group=16: 2 SF K-columns per BLOCK_K (128/16/4=2), but TMA K_box=1
|
||||
// Issue two TMA copies per K-block to load both SF columns
|
||||
for (uint32_t kk = 0; kk < 2; ++kk) {
|
||||
const uint32_t sfa_k_idx_full = k_block_idx * 2 + kk;
|
||||
tma::copy<SF_BLOCK_M, 1, 0>(
|
||||
tensor_map_sfa_ptr, full_barriers[stage_idx],
|
||||
smem_sfa[stage_idx] + kk * SF_BLOCK_M, // uint32* offset: SF_BLOCK_M ints per call
|
||||
sfa_m_idx, sfa_k_idx_full, 2);
|
||||
}
|
||||
if (is_leader_cta) {
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SMEM_SFA_SIZE_PER_STAGE * 2);
|
||||
} else {
|
||||
@@ -781,15 +787,20 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N;
|
||||
uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset
|
||||
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
|
||||
uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx;
|
||||
|
||||
// TMA copy weights with SF
|
||||
if (cute::elect_one_sync()) {
|
||||
// NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes
|
||||
tma::copy<BLOCK_K / 2, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(
|
||||
tensor_map_b_ptr, full_barriers[stage_idx], reinterpret_cast<uint8_t*>(smem_b[stage_idx]), k_idx, n_idx, 2);
|
||||
tma::copy<BLOCK_N, 1, 0>(
|
||||
tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
|
||||
// NVFP4 group=16: 2 SF K-columns per BLOCK_K, but TMA K_box=1
|
||||
for (uint32_t kk = 0; kk < 2; ++kk) {
|
||||
const uint32_t sfb_k_idx_full = local_expert_idx * shape_sfb_k + k_block_idx * 2 + kk;
|
||||
tma::copy<BLOCK_N, 1, 0>(
|
||||
tensor_map_sfb_ptr, full_barriers[stage_idx],
|
||||
smem_sfb[stage_idx] + kk * BLOCK_N, // uint32* offset: BLOCK_N ints per call
|
||||
sfb_n_idx, sfb_k_idx_full, 2);
|
||||
}
|
||||
if (is_leader_cta) {
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE * 2);
|
||||
} else {
|
||||
@@ -868,20 +879,28 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
|
||||
if (cute::elect_one_sync()) {
|
||||
// UTCCP copy SFA and SFB to TMEM
|
||||
// NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols
|
||||
// NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2)
|
||||
// Each UTCCP call moves 128 int32s → 4 TMEM cols
|
||||
// We need 2 UTCCP calls per SF: one per K-column
|
||||
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
for (uint32_t kk = 0; kk < 2; ++kk) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
for (uint32_t kk = 0; kk < 2; ++kk) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4);
|
||||
}
|
||||
}
|
||||
|
||||
// Issue UMMA
|
||||
|
||||
Reference in New Issue
Block a user