fix: use mxf8f6f4 (UE8M0) on SM100 — mxf4nvf4 requires SM103+
B200 (SM100) does NOT support kind::mxf4nvf4 at all (neither 2X nor 4X). Only mxf8f6f4.block_scale with UE8M0 scales is available on SM100. Strategy: keep NVFP4 E2M1 weights, convert UE4M3 block scales → UE8M0 in the weight transformation. This is a scale format adaptation for hardware compatibility, not a format conversion. Changes: - Kernel: back to mxf8f6F4 instruction + float_ue8m0_t descriptor - L1 epilogue: back to UE8M0 (>> 23) activation scales - Python: merge block16→block32, convert UE4M3→float32→UE8M0 - Packing: uint8 (UE8M0) → int32, same as MXFP4
This commit is contained in:
@@ -785,10 +785,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// GEMM MMA issue warp (only the leader CTA will run)
|
||||
if (is_leader_cta) {
|
||||
// NVFP4: use float_ue4m3_t scale factor type with mxf4nvf4 instruction
|
||||
// NOTES: always swap A/B
|
||||
// NVFP4 on SM100: use mxf8f6f4 instruction with UE8M0 scales
|
||||
// (mxf4nvf4 requires SM103+; B200 is SM100)
|
||||
// We convert UE4M3→UE8M0 in the weight transformation
|
||||
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
|
||||
b_dtype_t, a_dtype_t, float, cutlass::float_ue4m3_t,
|
||||
b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
|
||||
UMMA_M, UMMA_N,
|
||||
cute::UMMA::Major::K, cute::UMMA::Major::K
|
||||
>();
|
||||
@@ -870,8 +871,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
|
||||
b_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
|
||||
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
|
||||
ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma(
|
||||
b_desc, a_desc, accum_stage_idx * UMMA_N,
|
||||
k_block_idx > 0 or k > 0, runtime_instr_desc,
|
||||
kTmemStartColOfSFB, kTmemStartColOfSFA);
|
||||
@@ -1097,15 +1097,12 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
|
||||
// NVFP4: convert float scale to UE4M3 format
|
||||
// UE4M3: sign=0 + 4 exp + 3 mantissa, max=448
|
||||
auto to_ue4m3 = [](float v) -> uint8_t {
|
||||
v = fmaxf(0.0f, fminf(v, 448.0f));
|
||||
cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v);
|
||||
return reinterpret_cast<uint8_t&>(e4m3_val) & 0x7F;
|
||||
};
|
||||
sf_base_ptr[sf_addr] = to_ue4m3(sf.x);
|
||||
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] = to_ue4m3(sf.y);
|
||||
// NVFP4 on SM100: convert float scale to UE8M0 (power-of-2)
|
||||
// UE8M0: 8-bit exponent, no mantissa, represents 2^(exp-127)
|
||||
sf_base_ptr[sf_addr] =
|
||||
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
|
||||
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
|
||||
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user