fix: packed FP4 for mxf4nvf4 — correct SMEM layout, UMMA descriptors, L1 epilogue

Key changes:
- a_dtype_t/b_dtype_t: float_e2m1_t (packed 4-bit) with sizeof_bits_v==4 assert
- kSwizzleAMode/BMode: BLOCK_K/2 (64 bytes packed, not 128 unpacked)
- SMEM sizes: LOAD_BLOCK_M * BLOCK_K / 2 (packed byte count)
- Token layouts: kHidden/2, kIntermediateHidden/2 (packed bytes)
- TMA loads: BLOCK_K/2 inner dim, uint8_t, byte offsets k_block_idx*(BLOCK_K/2)
- UMMA descriptors: BLOCK_K/2 template param, uint8_t dtype, UMMA_K/2 advance
- L1 epilogue: dropped STSM, direct st.shared.u16 with packed nibbles, no swizzle (v1)
- Pybind buffer sizes: hidden/2, intermediate_hidden/2 with packed tensor shapes
- Host TMA descriptors: hidden/2 K-dims, block_k/2 inner, fp4_unpacked_smem=false
- L1 output TMA: block_n/4 inner, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE)
This commit is contained in:
2026-05-11 21:59:21 +00:00
parent 0ac73a82f9
commit 30d72e7ef5
3 changed files with 105 additions and 96 deletions

View File

@@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes
// group_size=16, so SF stride is K/16 (twice as many as MXFP4)
const auto fp4_token_layout = layout::Data(hidden);
const auto fp4_token_layout = layout::Data(hidden / 2);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2);
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
@@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden},
{num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
auto x_sf = torch::from_blob(
@@ -110,10 +110,10 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
{num_max_tokens_per_rank, num_topk},
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
// NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes
// NVFP4: L1 output acts are E2M1 packed, hidden/2 bytes
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden},
{num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L1 SF: M-major, K/64 int32
auto l1_acts_sf = torch::from_blob(
@@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden},
{num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L2 SF: M-major, K/64 int32
auto l2_acts_sf = torch::from_blob(

View File

@@ -134,14 +134,14 @@ static void sm100_fp8_nvfp4_mega_moe(
// NVFP4: kGranK=16 for group_size=16
constexpr int kGranK = 16;
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
// (E2M1 packed uint8 is the same format regardless of scale type)
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
// Make tensormap — NVFP4 packed E2M1 (2 per byte), so K-dim is hidden/2 bytes
// L1 activations: packed E2M1 (4 bits/element), K-dim = hidden/2
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
static_cast<int>(l1_acts.stride(-2)),
config.swizzle_acts_mode);
config.swizzle_acts_mode / 2,
0, false, false); // fp4_unpacked_smem=false
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
config.num_padded_sf_pool_tokens, hidden,
@@ -156,18 +156,21 @@ static void sm100_fp8_nvfp4_mega_moe(
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1)
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden, config.num_max_pool_tokens,
config.block_n / 2, config.store_block_m,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_n / 4, config.store_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 2);
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
0, 0, // no swizzle
false, // allow_tf32
false); // fp4_unpacked_smem=false (packed!)
// L2 activations: packed E2M1, K-dim = intermediate_hidden/2
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
intermediate_hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode);
config.swizzle_acts_mode / 2,
0, false, false); // fp4_unpacked_smem=false
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
config.num_padded_sf_pool_tokens, intermediate_hidden,
config.sf_block_m, kGranK,