From fbdddaccf4bd365509e4a48173b33e29ed2d7c10 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 15:02:47 +0000 Subject: [PATCH] revert: restore mxf4nvf4/block16 code (correct path for sm_100a) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverted to commit 36b439e's NVFP4 kernel code: - kGranK=16, mxf4nvf4.block_scale.scale_vec::4X - float_ue4m3_t instruction descriptor - Block16 SF layout (4X TMEM) - UE4M3 L1 epilogue - No UE4M3→UE8M0 conversion, no block16→block32 merge The mxf4nvf4.scale_vec::4X PTX instruction compiles successfully on both sm_100 and sm_100f with CUDA 13.0. The previous build 17 error was likely from a different cause, not the arch flag. Python: reverted transform_nvfp4_weights_for_mega_moe to use pack_ue4m3_to_int32 with gran_k=16, no UE8M0 conversion. --- csrc/apis/layout.hpp | 1 - csrc/apis/mega_nvfp4.hpp | 18 +-- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 59 ++++---- deep_gemm/include/deep_gemm/ptx/tcgen05.cuh | 4 +- deep_gemm/mega/__init__.py | 134 ++---------------- 5 files changed, 56 insertions(+), 160 deletions(-) diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index 15b104b..b404241 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -53,7 +53,6 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, } // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major - // Supports gran_k=32 (MXFP4 and NVFP4-block32), 128 (FP8) if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp index 7467519..d47bfb0 100644 --- a/csrc/apis/mega_nvfp4.hpp +++ b/csrc/apis/mega_nvfp4.hpp @@ -30,8 +30,8 @@ get_symm_buffer_size_for_nvfp4_mega_moe( const auto fp8_token_layout = layout::Data(hidden); const auto bf16_token_layout = layout::Data(hidden * 2); const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); - const auto nvfp4_sf_layout = layout::Data(hidden / 32); - const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 32); + const auto nvfp4_sf_layout = layout::Data(hidden / 16); + const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -86,7 +86,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // Check SF buffer requirements // NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16) - DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); + DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0); DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); // Slice function @@ -98,7 +98,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 auto x_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), - {num_max_tokens_per_rank, hidden / 128}, + {num_max_tokens_per_rank, hidden / 64}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); auto topk_idx = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), @@ -115,7 +115,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 L1 SF: M-major, K/64 int32 auto l1_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), - {num_max_padded_sf_pool_tokens, hidden / 128}, + {num_max_padded_sf_pool_tokens, hidden / 64}, {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); auto l2_acts = torch::from_blob( @@ -125,7 +125,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 L2 SF: M-major, K/64 int32 auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), - {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {num_max_padded_sf_pool_tokens, intermediate_hidden / 64}, {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); @@ -153,7 +153,7 @@ static void fp8_nvfp4_mega_moe( // Config checks const auto num_tokens = static_cast(y.size(0)); const auto [rm, rn, rk] = recipe; - DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32); // NVFP4 block32: group_size=32 + DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16); // NVFP4: group_size=16 DG_HOST_ASSERT(activation == "swiglu"); // Activation checks @@ -175,8 +175,8 @@ static void fp8_nvfp4_mega_moe( DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); // Check weight SF layout for UE4M3 packing, MN-major, and TMA alignment - // NVFP4 block32: kGranK=32, SF packed as int32 (4 UE4M3 bytes per int32) - constexpr int kGranMN = 1, kGranK = 32; + // NVFP4: kGranK=16, SF packed as int32 (4 UE4M3 bytes per int32) + constexpr int kGranMN = 1, kGranK = 16; check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, num_experts_per_rank, true, false, torch::kInt); check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 95f2aeb..a2170e5 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -98,9 +98,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, constexpr auto fp8_token_layout = layout::Data(kHidden); constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); - // NVFP4: scale_vec::2X (block32) on SM100, same SF stride as MXFP4 - constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); - constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + // NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4) + constexpr auto fp8_sf_layout = layout::Data(kHidden / 16); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -120,10 +120,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, input_topk_idx_buffer.get_end_ptr()); // SF and its buffer configs - // NVFP4 on SM100: scale_vec::2X (block32), group_size=32 with UE4M3 scales - // Note: scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100 - // So we use block32 and merge pairs of NVFP4 block16 scales - constexpr uint32_t kGranK = 32; + // NVFP4: group_size=16 → kGranK=16 (vs MXFP4's 32) + constexpr uint32_t kGranK = 16; // For NVFP4 scale_vec::4X, UTCCP alignment is still 128 elements constexpr uint32_t kNumUTCCPAlignedElems = 128; DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); @@ -222,9 +220,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Tensor memory size constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; - // NVFP4 scale_vec::2X: same TMEM layout as MXFP4 - constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; - constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + // NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row + // For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns + // SFB uses BLOCK_N/32 rows × 4 cols + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4; constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; @@ -563,9 +563,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, __syncwarp(); // Load and store SF (overlaps with TMA token load) - // NVFP4 block32: same SF uint32 count as MXFP4 - constexpr uint32_t kNumSFUint32 = kHidden / 128; - DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + // NVFP4: group_size=16, 4 UE4M3 scales per uint32 + constexpr uint32_t kNumSFUint32 = kHidden / 64; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 64 == 0, "Invalid SF"); const auto remote_sf_ptr = sym_buffer.map( input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx); @@ -785,11 +785,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // GEMM MMA issue warp (only the leader CTA will run) if (is_leader_cta) { - // NVFP4 on SM100: use mxf8f6f4 instruction with UE8M0 scales - // (mxf4nvf4 requires SM103+; B200 is SM100) - // We convert UE4M3→UE8M0 in the weight transformation + // NVFP4: use float_ue4m3_t scale factor type with mxf4nvf4 instruction + // NOTES: always swap A/B auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< - b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + b_dtype_t, a_dtype_t, float, cutlass::float_ue4m3_t, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K >(); @@ -847,19 +846,21 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); if (cute::elect_one_sync()) { // UTCCP copy SFA and SFB to TMEM - // NVFP4 scale_vec::2X: same layout as MXFP4 + // NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + // NVFP4 4X: 8 TMEM columns per 128-element SF group + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8); } #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8); } // Issue UMMA @@ -871,7 +872,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); b_desc.lo = mma::sm100::advance_umma_desc_lo< cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); - ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + // NVFP4: use mxf4nvf4 instruction with UE4M3 scales + ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N, k_block_idx > 0 or k > 0, runtime_instr_desc, kTmemStartColOfSFB, kTmemStartColOfSFA); @@ -1097,12 +1099,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; - // NVFP4 on SM100: convert float scale to UE8M0 (power-of-2) - // UE8M0: 8-bit exponent, no mantissa, represents 2^(exp-127) - sf_base_ptr[sf_addr] = - (*reinterpret_cast(&sf.x) >> 23); - sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = - (*reinterpret_cast(&sf.y) >> 23); + // NVFP4: convert float scale to UE4M3 format + // UE4M3: sign=0 + 4 exp + 3 mantissa, max=448 + auto to_ue4m3 = [](float v) -> uint8_t { + v = fmaxf(0.0f, fminf(v, 448.0f)); + cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v); + return reinterpret_cast(e4m3_val) & 0x7F; + }; + sf_base_ptr[sf_addr] = to_ue4m3(sf.x); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = to_ue4m3(sf.y); } __syncwarp(); } diff --git a/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh b/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh index f1bfbb1..f4ed99f 100644 --- a/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh +++ b/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh @@ -153,7 +153,7 @@ struct SM100_MMA_MXF4NVF4_2x1SM_SS { "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), @@ -175,7 +175,7 @@ struct SM100_MMA_MXF4NVF4_SS { "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index f30294e..52fbb35 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -138,93 +138,22 @@ def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: def transform_nvfp4_weights_for_mega_moe( l1_weights: Tuple[torch.Tensor, torch.Tensor], - l2_weights: Tuple[torch.Tensor, torch.Tensor], - l1_weight_scale_2: Optional[torch.Tensor] = None, - l2_weight_scale_2: Optional[torch.Tensor] = None + l2_weights: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: """Transform NVFP4 expert weights for the mega_moe kernel. - Uses deep_gemm.transform_sf_into_required_layout for proper TMA-aligned - UTCCP layout with recipe (1, 1, 16) for NVFP4 group_size=16. + NVFP4 weights come as (weight, scale) where: + - weight: uint8 E2M1 packed, shape (num_experts, N, K//2) + - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16) + + The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout. """ - from deep_gemm import transform_sf_into_required_layout - - def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor: - if scale_2 is None: - return sf - sf_f32 = sf.to(torch.float32) - if scale_2.dim() == 1: - scale_2 = scale_2.view(-1, 1, 1) - sf_f32 = sf_f32 * scale_2 - sf_f32 = sf_f32.clamp(0.0, 448.0) - return sf_f32.to(torch.float8_e4m3fn) - - l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2) - l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2) - - # Merge NVFP4 block16 scales → block32 for SM100 (scale_vec::2X) - # B200 (SM100) doesn't support scale_vec::4X (block16) — requires SM103/SM120 - # Take the max of each pair of adjacent block16 scales for block32 - def merge_block16_to_block32(sf): - # sf: (experts, mn, K//16) float8_e4m3fn - # output: (experts, mn, K//32) uint8 (UE8M0) - # SM100 (B200) doesn't support mxf4nvf4 — must use mxf8f6f4 with UE8M0 scales - # Convert UE4M3 → float32 → UE8M0 (power-of-2) - sf_f32 = sf.to(torch.float32) - # Take max of adjacent pairs - sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2]) - # Convert to UE8M0: extract exponent byte from float32 bit pattern - # UE8M0: uint8 = float32_bits[31:23] (8 exponent bits) - # Note: PyTorch doesn't support >> on uint32, cast to int32 first - sf_bits = sf_merged.view(torch.int32) # reinterpret float32 bits as int32 - sf_ue8m0 = ((sf_bits >> 23) & 0xFF).to(torch.uint8) - return sf_ue8m0 - - l1_sf_32 = merge_block16_to_block32(l1_sf) - l2_sf_32 = merge_block16_to_block32(l2_sf) - - num_experts = l1_weights[0].shape[0] - l1_n = l1_weights[0].shape[1] - l1_k = l1_weights[0].shape[2] * 2 - l2_n = l2_weights[0].shape[1] - l2_k = l2_weights[0].shape[2] * 2 - - # Pack UE8M0 (uint8) block scales into int32 for DeepGEMM TMA consumption - # Same packing as MXFP4: 4 uint8 → 1 int32 - def pack_uint8_to_int32(sf): - assert sf.dtype == torch.uint8 - assert sf.shape[-1] % 4 == 0 - packed = (sf[..., 0::4].to(torch.int32) | - (sf[..., 1::4].to(torch.int32) << 8) | - (sf[..., 2::4].to(torch.int32) << 16) | - (sf[..., 3::4].to(torch.int32) << 24)) - return packed.contiguous() - - l1_sf_packed = pack_uint8_to_int32(l1_sf_32) - l2_sf_packed = pack_uint8_to_int32(l2_sf_32) - - print(f"[NVFP4-MoE] l1_sf_32: shape={l1_sf_32.shape}, l1_sf_packed: shape={l1_sf_packed.shape}") - print(f"[NVFP4-MoE] l2_sf_32: shape={l2_sf_32.shape}, l2_sf_packed: shape={l2_sf_packed.shape}") - print(f"[NVFP4-MoE] l1_n={l1_n} l1_k={l1_k} l2_n={l2_n} l2_k={l2_k}") - - # Transpose to MN-major layout (stride(-2)=1) and make contiguous - # transform_sf_into_required_layout expects MN-major input for TMA stride checks - l1_sf_mn = l1_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1) - l2_sf_mn = l2_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1) - - # Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function - # recipe (1, 32): gran_mn=1, gran_k=16 - l1_sf_transformed = transform_sf_into_required_layout( - l1_sf_mn, l1_n, l1_k, (1, 32), num_experts) - l2_sf_transformed = transform_sf_into_required_layout( - l2_sf_mn, l2_n, l2_k, (1, 32), num_experts) - - # L1: interleave gate/up - l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed)) - # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8) - l1_out = (l1_interleaved[0].view(torch.int8), l1_sf_transformed) - l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed) - return l1_out, l2_out + # L1: interleave gate/up, then pack + transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) + # L2: only pack + transpose SF for UTCCP + l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights def fp8_fp4_mega_moe(y: torch.Tensor, @@ -250,49 +179,12 @@ def fp8_fp4_mega_moe(y: torch.Tensor, ) -def get_symm_buffer_for_nvfp4_mega_moe( - group: "dist.ProcessGroup", - num_experts: int, - num_max_tokens_per_rank: int, num_topk: int, - hidden: int, intermediate_hidden: int, - use_fp8_dispatch: bool = True, - activation: str = 'swiglu') -> SymmBuffer: - """Allocate a SymmBuffer sized for NVFP4 mega_moe (group_size=16).""" - from .. import _C - num_max_tokens_per_rank = align(num_max_tokens_per_rank, - _C.get_token_alignment_for_nvfp4_mega_moe()) - buf = SymmBuffer.__new__(SymmBuffer) - buf.group = group - buf.num_experts = num_experts - buf.num_max_tokens_per_rank = num_max_tokens_per_rank - buf.num_topk = num_topk - buf.hidden = hidden - buf.intermediate_hidden = intermediate_hidden - # Use NVFP4-specific buffer size (2x SF due to group_size=16) - num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe( - group.size(), num_experts, - num_max_tokens_per_rank, num_topk, - hidden, intermediate_hidden, - use_fp8_dispatch, activation) - import torch.distributed._symmetric_memory as symm_mem - import torch.distributed as dist - buf.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda') - buf.handle = symm_mem.rendezvous(buf.buffer, group=group) - buf.buffer.zero_() - buf.group.barrier() - torch.cuda.synchronize() - buf.x, buf.x_sf, buf.topk_idx, buf.topk_weights, \ - buf.l1_acts, buf.l1_acts_sf, buf.l2_acts, buf.l2_acts_sf = \ - slice_input_buffers(buf.buffer) - return buf - - def fp8_nvfp4_mega_moe(y: torch.Tensor, l1_weights: Tuple[torch.Tensor, torch.Tensor], l2_weights: Tuple[torch.Tensor, torch.Tensor], sym_buffer: SymmBuffer, cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, - recipe: Tuple[int, int, int] = (1, 1, 32), + recipe: Tuple[int, int, int] = (1, 1, 16), activation: str = 'swiglu', activation_clamp: Optional[float] = None, fast_math: bool = True):