fix: L1 output uses unpacked E2M1 (1 byte/element) like FP8

- float_e2m1_unpacksmem_t: sizeof=1, SMEM is 1 byte/element (not packed)
- TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes
- UMMA reads unpacked SMEM, packs internally for mxf4nvf4
- L1→L2 handoff: unpacked format (same byte count as FP8)
- Epilogue: 4 E2M1 bytes per uint32 STSM atom, same as FP8
- Dispatch TMA: kHidden bytes (unpacked), not kHidden/2
- Added static_assert on sizeof(a_dtype_t) and sizeof(b_dtype_t)
- Note: no bandwidth savings at L1→L2 boundary for v1
This commit is contained in:
2026-05-11 21:27:35 +00:00
parent 091b974736
commit 0ac73a82f9
3 changed files with 52 additions and 46 deletions

View File

@@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes
// group_size=16, so SF stride is K/16 (twice as many as MXFP4)
const auto fp4_token_layout = layout::Data(hidden / 2);
const auto fp4_token_layout = layout::Data(hidden);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
@@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden / 2},
{num_max_tokens_per_rank, hidden},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
auto x_sf = torch::from_blob(
@@ -113,7 +113,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden / 2},
{num_max_pool_tokens, hidden},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L1 SF: M-major, K/64 int32
auto l1_acts_sf = torch::from_blob(
@@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden / 2},
{num_max_pool_tokens, intermediate_hidden},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L2 SF: M-major, K/64 int32
auto l2_acts_sf = torch::from_blob(

View File

@@ -136,12 +136,12 @@ static void sm100_fp8_nvfp4_mega_moe(
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
// (E2M1 packed uint8 is the same format regardless of scale type)
// NVFP4: activations are E2M1 packed uint8, so K-dim is hidden/2 bytes
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l1_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode);
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
config.num_padded_sf_pool_tokens, hidden,
@@ -156,18 +156,18 @@ static void sm100_fp8_nvfp4_mega_moe(
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// NVFP4: L1 output is E2M1 packed, intermediate_hidden/2 bytes per row
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_n / 4, config.store_block_m,
intermediate_hidden, config.num_max_pool_tokens,
config.block_n / 2, config.store_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 4);
// NVFP4: L2 activations are E2M1 packed, intermediate_hidden/2 bytes per row
config.swizzle_acts_mode / 2);
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
intermediate_hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode);
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
config.num_padded_sf_pool_tokens, intermediate_hidden,
config.sf_block_m, kGranK,