From 03b8c99ee1885a69050cbf98d39136256366f6c8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 09:28:45 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20use=20mxf8f6f4=20(UE8M0)=20on=20SM100=20?= =?UTF-8?q?=E2=80=94=20mxf4nvf4=20requires=20SM103+?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit B200 (SM100) does NOT support kind::mxf4nvf4 at all (neither 2X nor 4X). Only mxf8f6f4.block_scale with UE8M0 scales is available on SM100. Strategy: keep NVFP4 E2M1 weights, convert UE4M3 block scales → UE8M0 in the weight transformation. This is a scale format adaptation for hardware compatibility, not a format conversion. Changes: - Kernel: back to mxf8f6F4 instruction + float_ue8m0_t descriptor - L1 epilogue: back to UE8M0 (>> 23) activation scales - Python: merge block16→block32, convert UE4M3→float32→UE8M0 - Packing: uint8 (UE8M0) → int32, same as MXFP4 --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 25 ++++++-------- deep_gemm/mega/__init__.py | 34 +++++++++++-------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 1becc65..95f2aeb 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -785,10 +785,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // GEMM MMA issue warp (only the leader CTA will run) if (is_leader_cta) { - // NVFP4: use float_ue4m3_t scale factor type with mxf4nvf4 instruction - // NOTES: always swap A/B + // NVFP4 on SM100: use mxf8f6f4 instruction with UE8M0 scales + // (mxf4nvf4 requires SM103+; B200 is SM100) + // We convert UE4M3→UE8M0 in the weight transformation auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< - b_dtype_t, a_dtype_t, float, cutlass::float_ue4m3_t, + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K >(); @@ -870,8 +871,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); b_desc.lo = mma::sm100::advance_umma_desc_lo< cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); - // NVFP4: use mxf4nvf4 instruction with UE4M3 scales - ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N, k_block_idx > 0 or k > 0, runtime_instr_desc, kTmemStartColOfSFB, kTmemStartColOfSFA); @@ -1097,15 +1097,12 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; - // NVFP4: convert float scale to UE4M3 format - // UE4M3: sign=0 + 4 exp + 3 mantissa, max=448 - auto to_ue4m3 = [](float v) -> uint8_t { - v = fmaxf(0.0f, fminf(v, 448.0f)); - cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v); - return reinterpret_cast(e4m3_val) & 0x7F; - }; - sf_base_ptr[sf_addr] = to_ue4m3(sf.x); - sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = to_ue4m3(sf.y); + // NVFP4 on SM100: convert float scale to UE8M0 (power-of-2) + // UE8M0: 8-bit exponent, no mantissa, represents 2^(exp-127) + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); } __syncwarp(); } diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index ea2a103..5417f0d 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -167,11 +167,17 @@ def transform_nvfp4_weights_for_mega_moe( # Take the max of each pair of adjacent block16 scales for block32 def merge_block16_to_block32(sf): # sf: (experts, mn, K//16) float8_e4m3fn - # output: (experts, mn, K//32) float8_e4m3fn + # output: (experts, mn, K//32) uint8 (UE8M0) + # SM100 (B200) doesn't support mxf4nvf4 — must use mxf8f6f4 with UE8M0 scales + # Convert UE4M3 → float32 → UE8M0 (power-of-2) sf_f32 = sf.to(torch.float32) - # Take max of adjacent pairs (preserves magnitude, avoids underflow) + # Take max of adjacent pairs sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2]) - return sf_merged.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + # Convert to UE8M0: 2^(floor(log2(v)) - 127 + 127) = extract exponent byte + # UE8M0 encoding: uint8 = float32_exponent_bits >> 23 + sf_u32 = sf_merged.view(torch.uint32) + sf_ue8m0 = (sf_u32 >> 23).to(torch.uint8) + return sf_ue8m0 l1_sf_32 = merge_block16_to_block32(l1_sf) l2_sf_32 = merge_block16_to_block32(l2_sf) @@ -182,19 +188,19 @@ def transform_nvfp4_weights_for_mega_moe( l2_n = l2_weights[0].shape[1] l2_k = l2_weights[0].shape[2] * 2 - # Pack UE4M3 (float8_e4m3fn) into int32 for DeepGEMM TMA consumption - # 4 UE4M3 bytes → 1 int32, matching the hardware's 4X scale vector - def pack_ue4m3_to_int32(sf): - sf_u8 = sf.view(torch.uint8) - assert sf_u8.shape[-1] % 4 == 0 - packed = (sf_u8[..., 0::4].to(torch.int32) | - (sf_u8[..., 1::4].to(torch.int32) << 8) | - (sf_u8[..., 2::4].to(torch.int32) << 16) | - (sf_u8[..., 3::4].to(torch.int32) << 24)) + # Pack UE8M0 (uint8) block scales into int32 for DeepGEMM TMA consumption + # Same packing as MXFP4: 4 uint8 → 1 int32 + def pack_uint8_to_int32(sf): + assert sf.dtype == torch.uint8 + assert sf.shape[-1] % 4 == 0 + packed = (sf[..., 0::4].to(torch.int32) | + (sf[..., 1::4].to(torch.int32) << 8) | + (sf[..., 2::4].to(torch.int32) << 16) | + (sf[..., 3::4].to(torch.int32) << 24)) return packed.contiguous() - l1_sf_packed = pack_ue4m3_to_int32(l1_sf_32) - l2_sf_packed = pack_ue4m3_to_int32(l2_sf_32) + l1_sf_packed = pack_uint8_to_int32(l1_sf_32) + l2_sf_packed = pack_uint8_to_int32(l2_sf_32) print(f"[NVFP4-MoE] l1_sf_32: shape={l1_sf_32.shape}, l1_sf_packed: shape={l1_sf_packed.shape}") print(f"[NVFP4-MoE] l2_sf_32: shape={l2_sf_32.shape}, l2_sf_packed: shape={l2_sf_packed.shape}")