From dcebe033e26f5c9802b4f82163085d123bf11612 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 08:36:59 +0000 Subject: [PATCH] fix: use scale_vec::2X (block32) for SM100 B200 compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100 (B200). Revert to block32 with UE4M3 scales. Same TMEM layout as MXFP4 but with UE4M3 scale format instead of UE8M0. Changes: - kGranK: 16 → 32 - PTX: scale_vec::4X → scale_vec::2X - SF layout: same as MXFP4 (K/32, K/128 for int32 packed) - UTCCP: i*8 → i*4 (2X layout, same as MXFP4) - TMEM columns: same as MXFP4 (SF_BLOCK_M/32, SF_BLOCK_N/32) - Python: merge NVFP4 block16→block32 scales (max of adjacent pairs) - recipe: (1,1,16) → (1,1,32) --- csrc/apis/layout.hpp | 4 +-- csrc/apis/mega_nvfp4.hpp | 14 ++++---- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 34 +++++++++---------- deep_gemm/include/deep_gemm/ptx/tcgen05.cuh | 4 +-- deep_gemm/mega/__init__.py | 26 ++++++++++---- 5 files changed, 47 insertions(+), 35 deletions(-) diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index f369e9b..15b104b 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -53,8 +53,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, } // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major - // Supports gran_k=16 (NVFP4), 32 (MXFP4), 128 (FP8) - if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 16 or gran_k == 32 or gran_k == 128) and arch_major == 10) + // Supports gran_k=32 (MXFP4 and NVFP4-block32), 128 (FP8) + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); DG_HOST_UNREACHABLE("Unknown SF transformation"); diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp index d47bfb0..83b215a 100644 --- a/csrc/apis/mega_nvfp4.hpp +++ b/csrc/apis/mega_nvfp4.hpp @@ -30,8 +30,8 @@ get_symm_buffer_size_for_nvfp4_mega_moe( const auto fp8_token_layout = layout::Data(hidden); const auto bf16_token_layout = layout::Data(hidden * 2); const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); - const auto nvfp4_sf_layout = layout::Data(hidden / 16); - const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); + const auto nvfp4_sf_layout = layout::Data(hidden / 32); + const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 32); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -86,7 +86,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // Check SF buffer requirements // NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16) - DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0); + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); // Slice function @@ -98,7 +98,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 auto x_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), - {num_max_tokens_per_rank, hidden / 64}, + {num_max_tokens_per_rank, hidden / 128}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); auto topk_idx = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), @@ -115,7 +115,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 L1 SF: M-major, K/64 int32 auto l1_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), - {num_max_padded_sf_pool_tokens, hidden / 64}, + {num_max_padded_sf_pool_tokens, hidden / 128}, {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); auto l2_acts = torch::from_blob( @@ -125,7 +125,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 L2 SF: M-major, K/64 int32 auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), - {num_max_padded_sf_pool_tokens, intermediate_hidden / 64}, + {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); @@ -153,7 +153,7 @@ static void fp8_nvfp4_mega_moe( // Config checks const auto num_tokens = static_cast(y.size(0)); const auto [rm, rn, rk] = recipe; - DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16); // NVFP4: group_size=16 + DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32); // NVFP4 block32: group_size=32 DG_HOST_ASSERT(activation == "swiglu"); // Activation checks diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index a2170e5..1becc65 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -98,9 +98,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, constexpr auto fp8_token_layout = layout::Data(kHidden); constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); - // NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4) - constexpr auto fp8_sf_layout = layout::Data(kHidden / 16); - constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); + // NVFP4: scale_vec::2X (block32) on SM100, same SF stride as MXFP4 + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -120,8 +120,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, input_topk_idx_buffer.get_end_ptr()); // SF and its buffer configs - // NVFP4: group_size=16 → kGranK=16 (vs MXFP4's 32) - constexpr uint32_t kGranK = 16; + // NVFP4 on SM100: scale_vec::2X (block32), group_size=32 with UE4M3 scales + // Note: scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100 + // So we use block32 and merge pairs of NVFP4 block16 scales + constexpr uint32_t kGranK = 32; // For NVFP4 scale_vec::4X, UTCCP alignment is still 128 elements constexpr uint32_t kNumUTCCPAlignedElems = 128; DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); @@ -220,11 +222,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Tensor memory size constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; - // NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row - // For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns - // SFB uses BLOCK_N/32 rows × 4 cols - constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4; - constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4; + // NVFP4 scale_vec::2X: same TMEM layout as MXFP4 + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; @@ -563,9 +563,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, __syncwarp(); // Load and store SF (overlaps with TMA token load) - // NVFP4: group_size=16, 4 UE4M3 scales per uint32 - constexpr uint32_t kNumSFUint32 = kHidden / 64; - DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 64 == 0, "Invalid SF"); + // NVFP4 block32: same SF uint32 count as MXFP4 + constexpr uint32_t kNumSFUint32 = kHidden / 128; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); const auto remote_sf_ptr = sym_buffer.map( input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx); @@ -846,21 +846,19 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); if (cute::elect_one_sync()) { // UTCCP copy SFA and SFB to TMEM - // NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols + // NVFP4 scale_vec::2X: same layout as MXFP4 using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; - #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - // NVFP4 4X: 8 TMEM columns per 128-element SF group - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); } #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); } // Issue UMMA diff --git a/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh b/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh index f4ed99f..f1bfbb1 100644 --- a/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh +++ b/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh @@ -153,7 +153,7 @@ struct SM100_MMA_MXF4NVF4_2x1SM_SS { "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), @@ -175,7 +175,7 @@ struct SM100_MMA_MXF4NVF4_SS { "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 57c807e..824fcd9 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -162,6 +162,20 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2) l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2) + # Merge NVFP4 block16 scales → block32 for SM100 (scale_vec::2X) + # B200 (SM100) doesn't support scale_vec::4X (block16) — requires SM103/SM120 + # Take the max of each pair of adjacent block16 scales for block32 + def merge_block16_to_block32(sf): + # sf: (experts, mn, K//16) float8_e4m3fn + # output: (experts, mn, K//32) float8_e4m3fn + sf_f32 = sf.to(torch.float32) + # Take max of adjacent pairs (preserves magnitude, avoids underflow) + sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2]) + return sf_merged.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + + l1_sf_32 = merge_block16_to_block32(l1_sf) + l2_sf_32 = merge_block16_to_block32(l2_sf) + num_experts = l1_weights[0].shape[0] l1_n = l1_weights[0].shape[1] l1_k = l1_weights[0].shape[2] * 2 @@ -179,8 +193,8 @@ def transform_nvfp4_weights_for_mega_moe( (sf_u8[..., 3::4].to(torch.int32) << 24)) return packed.contiguous() - l1_sf_packed = pack_ue4m3_to_int32(l1_sf) - l2_sf_packed = pack_ue4m3_to_int32(l2_sf) + l1_sf_packed = pack_ue4m3_to_int32(l1_sf_32) + l2_sf_packed = pack_ue4m3_to_int32(l2_sf_32) # Transpose to MN-major layout (stride(-2)=1) and make contiguous # transform_sf_into_required_layout expects MN-major input for TMA stride checks @@ -188,11 +202,11 @@ def transform_nvfp4_weights_for_mega_moe( l2_sf_mn = l2_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1) # Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function - # recipe (1, 16): gran_mn=1, gran_k=16 + # recipe (1, 32): gran_mn=1, gran_k=16 l1_sf_transformed = transform_sf_into_required_layout( - l1_sf_mn, l1_n, l1_k, (1, 16), num_experts) + l1_sf_mn, l1_n, l1_k, (1, 32), num_experts) l2_sf_transformed = transform_sf_into_required_layout( - l2_sf_mn, l2_n, l2_k, (1, 16), num_experts) + l2_sf_mn, l2_n, l2_k, (1, 32), num_experts) # L1: interleave gate/up l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed)) @@ -267,7 +281,7 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, l2_weights: Tuple[torch.Tensor, torch.Tensor], sym_buffer: SymmBuffer, cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, - recipe: Tuple[int, int, int] = (1, 1, 16), + recipe: Tuple[int, int, int] = (1, 1, 32), activation: str = 'swiglu', activation_clamp: Optional[float] = None, fast_math: bool = True):