fix: use scale_vec::2X (block32) for SM100 B200 compatibility
scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100 (B200). Revert to block32 with UE4M3 scales. Same TMEM layout as MXFP4 but with UE4M3 scale format instead of UE8M0. Changes: - kGranK: 16 → 32 - PTX: scale_vec::4X → scale_vec::2X - SF layout: same as MXFP4 (K/32, K/128 for int32 packed) - UTCCP: i*8 → i*4 (2X layout, same as MXFP4) - TMEM columns: same as MXFP4 (SF_BLOCK_M/32, SF_BLOCK_N/32) - Python: merge NVFP4 block16→block32 scales (max of adjacent pairs) - recipe: (1,1,16) → (1,1,32)
This commit is contained in:
@@ -53,8 +53,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
}
|
||||
|
||||
// (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
|
||||
// Supports gran_k=16 (NVFP4), 32 (MXFP4), 128 (FP8)
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 16 or gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
// Supports gran_k=32 (MXFP4 and NVFP4-block32), 128 (FP8)
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown SF transformation");
|
||||
|
||||
@@ -30,8 +30,8 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
const auto fp8_token_layout = layout::Data(hidden);
|
||||
const auto bf16_token_layout = layout::Data(hidden * 2);
|
||||
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
|
||||
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
|
||||
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
|
||||
const auto nvfp4_sf_layout = layout::Data(hidden / 32);
|
||||
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
|
||||
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
|
||||
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
|
||||
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
|
||||
@@ -86,7 +86,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
|
||||
// Check SF buffer requirements
|
||||
// NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16)
|
||||
DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0);
|
||||
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
|
||||
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
|
||||
|
||||
// Slice function
|
||||
@@ -98,7 +98,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
|
||||
auto x_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
|
||||
{num_max_tokens_per_rank, hidden / 64},
|
||||
{num_max_tokens_per_rank, hidden / 128},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
auto topk_idx = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
|
||||
@@ -115,7 +115,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4 L1 SF: M-major, K/64 int32
|
||||
auto l1_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
|
||||
{num_max_padded_sf_pool_tokens, hidden / 64},
|
||||
{num_max_padded_sf_pool_tokens, hidden / 128},
|
||||
{1, num_max_padded_sf_pool_tokens},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
auto l2_acts = torch::from_blob(
|
||||
@@ -125,7 +125,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4 L2 SF: M-major, K/64 int32
|
||||
auto l2_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
|
||||
{num_max_padded_sf_pool_tokens, intermediate_hidden / 64},
|
||||
{num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
|
||||
{1, num_max_padded_sf_pool_tokens},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
|
||||
@@ -153,7 +153,7 @@ static void fp8_nvfp4_mega_moe(
|
||||
// Config checks
|
||||
const auto num_tokens = static_cast<int>(y.size(0));
|
||||
const auto [rm, rn, rk] = recipe;
|
||||
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16); // NVFP4: group_size=16
|
||||
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32); // NVFP4 block32: group_size=32
|
||||
DG_HOST_ASSERT(activation == "swiglu");
|
||||
|
||||
// Activation checks
|
||||
|
||||
@@ -98,9 +98,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
constexpr auto fp8_token_layout = layout::Data(kHidden);
|
||||
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
|
||||
constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden);
|
||||
// NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4)
|
||||
constexpr auto fp8_sf_layout = layout::Data(kHidden / 16);
|
||||
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
|
||||
// NVFP4: scale_vec::2X (block32) on SM100, same SF stride as MXFP4
|
||||
constexpr auto fp8_sf_layout = layout::Data(kHidden / 32);
|
||||
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32);
|
||||
constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false);
|
||||
constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false);
|
||||
constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
|
||||
@@ -120,8 +120,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
input_topk_idx_buffer.get_end_ptr());
|
||||
|
||||
// SF and its buffer configs
|
||||
// NVFP4: group_size=16 → kGranK=16 (vs MXFP4's 32)
|
||||
constexpr uint32_t kGranK = 16;
|
||||
// NVFP4 on SM100: scale_vec::2X (block32), group_size=32 with UE4M3 scales
|
||||
// Note: scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100
|
||||
// So we use block32 and merge pairs of NVFP4 block16 scales
|
||||
constexpr uint32_t kGranK = 32;
|
||||
// For NVFP4 scale_vec::4X, UTCCP alignment is still 128 elements
|
||||
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
||||
DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M");
|
||||
@@ -220,11 +222,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Tensor memory size
|
||||
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
|
||||
// NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row
|
||||
// For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns
|
||||
// SFB uses BLOCK_N/32 rows × 4 cols
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4;
|
||||
// NVFP4 scale_vec::2X: same TMEM layout as MXFP4
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
||||
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
||||
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
||||
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
||||
@@ -563,9 +563,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
__syncwarp();
|
||||
|
||||
// Load and store SF (overlaps with TMA token load)
|
||||
// NVFP4: group_size=16, 4 UE4M3 scales per uint32
|
||||
constexpr uint32_t kNumSFUint32 = kHidden / 64;
|
||||
DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 64 == 0, "Invalid SF");
|
||||
// NVFP4 block32: same SF uint32 count as MXFP4
|
||||
constexpr uint32_t kNumSFUint32 = kHidden / 128;
|
||||
DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF");
|
||||
const auto remote_sf_ptr = sym_buffer.map(
|
||||
input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr<uint32_t>(),
|
||||
current_rank_in_expert_idx);
|
||||
@@ -846,21 +846,19 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
|
||||
if (cute::elect_one_sync()) {
|
||||
// UTCCP copy SFA and SFB to TMEM
|
||||
// NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols
|
||||
// NVFP4 scale_vec::2X: same layout as MXFP4
|
||||
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
// NVFP4 4X: 8 TMEM columns per 128-element SF group
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
}
|
||||
|
||||
// Issue UMMA
|
||||
|
||||
@@ -153,7 +153,7 @@ struct SM100_MMA_MXF4NVF4_2x1SM_SS {
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"}\n"
|
||||
:
|
||||
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
|
||||
@@ -175,7 +175,7 @@ struct SM100_MMA_MXF4NVF4_SS {
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"}\n"
|
||||
:
|
||||
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
|
||||
|
||||
@@ -162,6 +162,20 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
|
||||
l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)
|
||||
|
||||
# Merge NVFP4 block16 scales → block32 for SM100 (scale_vec::2X)
|
||||
# B200 (SM100) doesn't support scale_vec::4X (block16) — requires SM103/SM120
|
||||
# Take the max of each pair of adjacent block16 scales for block32
|
||||
def merge_block16_to_block32(sf):
|
||||
# sf: (experts, mn, K//16) float8_e4m3fn
|
||||
# output: (experts, mn, K//32) float8_e4m3fn
|
||||
sf_f32 = sf.to(torch.float32)
|
||||
# Take max of adjacent pairs (preserves magnitude, avoids underflow)
|
||||
sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2])
|
||||
return sf_merged.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
|
||||
l1_sf_32 = merge_block16_to_block32(l1_sf)
|
||||
l2_sf_32 = merge_block16_to_block32(l2_sf)
|
||||
|
||||
num_experts = l1_weights[0].shape[0]
|
||||
l1_n = l1_weights[0].shape[1]
|
||||
l1_k = l1_weights[0].shape[2] * 2
|
||||
@@ -179,8 +193,8 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
(sf_u8[..., 3::4].to(torch.int32) << 24))
|
||||
return packed.contiguous()
|
||||
|
||||
l1_sf_packed = pack_ue4m3_to_int32(l1_sf)
|
||||
l2_sf_packed = pack_ue4m3_to_int32(l2_sf)
|
||||
l1_sf_packed = pack_ue4m3_to_int32(l1_sf_32)
|
||||
l2_sf_packed = pack_ue4m3_to_int32(l2_sf_32)
|
||||
|
||||
# Transpose to MN-major layout (stride(-2)=1) and make contiguous
|
||||
# transform_sf_into_required_layout expects MN-major input for TMA stride checks
|
||||
@@ -188,11 +202,11 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l2_sf_mn = l2_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
|
||||
# Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function
|
||||
# recipe (1, 16): gran_mn=1, gran_k=16
|
||||
# recipe (1, 32): gran_mn=1, gran_k=16
|
||||
l1_sf_transformed = transform_sf_into_required_layout(
|
||||
l1_sf_mn, l1_n, l1_k, (1, 16), num_experts)
|
||||
l1_sf_mn, l1_n, l1_k, (1, 32), num_experts)
|
||||
l2_sf_transformed = transform_sf_into_required_layout(
|
||||
l2_sf_mn, l2_n, l2_k, (1, 16), num_experts)
|
||||
l2_sf_mn, l2_n, l2_k, (1, 32), num_experts)
|
||||
|
||||
# L1: interleave gate/up
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed))
|
||||
@@ -267,7 +281,7 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
sym_buffer: SymmBuffer,
|
||||
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
|
||||
recipe: Tuple[int, int, int] = (1, 1, 16),
|
||||
recipe: Tuple[int, int, int] = (1, 1, 32),
|
||||
activation: str = 'swiglu',
|
||||
activation_clamp: Optional[float] = None,
|
||||
fast_math: bool = True):
|
||||
|
||||
Reference in New Issue
Block a user