revert: restore mxf4nvf4/block16 code (correct path for sm_100a)

Reverted to commit 36b439e's NVFP4 kernel code:
- kGranK=16, mxf4nvf4.block_scale.scale_vec::4X
- float_ue4m3_t instruction descriptor
- Block16 SF layout (4X TMEM)
- UE4M3 L1 epilogue
- No UE4M3→UE8M0 conversion, no block16→block32 merge

The mxf4nvf4.scale_vec::4X PTX instruction compiles successfully
on both sm_100 and sm_100f with CUDA 13.0. The previous build 17
error was likely from a different cause, not the arch flag.

Python: reverted transform_nvfp4_weights_for_mega_moe to use
pack_ue4m3_to_int32 with gran_k=16, no UE8M0 conversion.
This commit is contained in:
2026-05-11 15:02:47 +00:00
parent e80fe9af60
commit fbdddaccf4
5 changed files with 56 additions and 160 deletions

View File

@@ -98,9 +98,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
constexpr auto fp8_token_layout = layout::Data(kHidden);
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden);
// NVFP4: scale_vec::2X (block32) on SM100, same SF stride as MXFP4
constexpr auto fp8_sf_layout = layout::Data(kHidden / 32);
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32);
// NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4)
constexpr auto fp8_sf_layout = layout::Data(kHidden / 16);
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false);
constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false);
constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
@@ -120,10 +120,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
input_topk_idx_buffer.get_end_ptr());
// SF and its buffer configs
// NVFP4 on SM100: scale_vec::2X (block32), group_size=32 with UE4M3 scales
// Note: scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100
// So we use block32 and merge pairs of NVFP4 block16 scales
constexpr uint32_t kGranK = 32;
// NVFP4: group_size=16 → kGranK=16 (vs MXFP4's 32)
constexpr uint32_t kGranK = 16;
// For NVFP4 scale_vec::4X, UTCCP alignment is still 128 elements
constexpr uint32_t kNumUTCCPAlignedElems = 128;
DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M");
@@ -222,9 +220,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Tensor memory size
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
// NVFP4 scale_vec::2X: same TMEM layout as MXFP4
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
// NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row
// For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns
// SFB uses BLOCK_N/32 rows × 4 cols
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
@@ -563,9 +563,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
__syncwarp();
// Load and store SF (overlaps with TMA token load)
// NVFP4 block32: same SF uint32 count as MXFP4
constexpr uint32_t kNumSFUint32 = kHidden / 128;
DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF");
// NVFP4: group_size=16, 4 UE4M3 scales per uint32
constexpr uint32_t kNumSFUint32 = kHidden / 64;
DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 64 == 0, "Invalid SF");
const auto remote_sf_ptr = sym_buffer.map(
input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr<uint32_t>(),
current_rank_in_expert_idx);
@@ -785,11 +785,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// GEMM MMA issue warp (only the leader CTA will run)
if (is_leader_cta) {
// NVFP4 on SM100: use mxf8f6f4 instruction with UE8M0 scales
// (mxf4nvf4 requires SM103+; B200 is SM100)
// We convert UE4M3→UE8M0 in the weight transformation
// NVFP4: use float_ue4m3_t scale factor type with mxf4nvf4 instruction
// NOTES: always swap A/B
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
b_dtype_t, a_dtype_t, float, cutlass::float_ue4m3_t,
UMMA_M, UMMA_N,
cute::UMMA::Major::K, cute::UMMA::Major::K
>();
@@ -847,19 +846,21 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
if (cute::elect_one_sync()) {
// UTCCP copy SFA and SFB to TMEM
// NVFP4 scale_vec::2X: same layout as MXFP4
// NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
// NVFP4 4X: 8 TMEM columns per 128-element SF group
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8);
}
// Issue UMMA
@@ -871,7 +872,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma(
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
b_desc, a_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc,
kTmemStartColOfSFB, kTmemStartColOfSFA);
@@ -1097,12 +1099,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
// NVFP4 on SM100: convert float scale to UE8M0 (power-of-2)
// UE8M0: 8-bit exponent, no mantissa, represents 2^(exp-127)
sf_base_ptr[sf_addr] =
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
// NVFP4: convert float scale to UE4M3 format
// UE4M3: sign=0 + 4 exp + 3 mantissa, max=448
auto to_ue4m3 = [](float v) -> uint8_t {
v = fmaxf(0.0f, fminf(v, 448.0f));
cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v);
return reinterpret_cast<uint8_t&>(e4m3_val) & 0x7F;
};
sf_base_ptr[sf_addr] = to_ue4m3(sf.x);
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] = to_ue4m3(sf.y);
}
__syncwarp();
}

View File

@@ -153,7 +153,7 @@ struct SM100_MMA_MXF4NVF4_2x1SM_SS {
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
@@ -175,7 +175,7 @@ struct SM100_MMA_MXF4NVF4_SS {
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),