fix: L1 output uses unpacked E2M1 (1 byte/element) like FP8

- float_e2m1_unpacksmem_t: sizeof=1, SMEM is 1 byte/element (not packed)
- TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes
- UMMA reads unpacked SMEM, packs internally for mxf4nvf4
- L1→L2 handoff: unpacked format (same byte count as FP8)
- Epilogue: 4 E2M1 bytes per uint32 STSM atom, same as FP8
- Dispatch TMA: kHidden bytes (unpacked), not kHidden/2
- Added static_assert on sizeof(a_dtype_t) and sizeof(b_dtype_t)
- Note: no bandwidth savings at L1→L2 boundary for v1
This commit is contained in:
2026-05-11 21:27:35 +00:00
parent 091b974736
commit 0ac73a82f9
3 changed files with 52 additions and 46 deletions

View File

@@ -136,12 +136,12 @@ static void sm100_fp8_nvfp4_mega_moe(
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
// (E2M1 packed uint8 is the same format regardless of scale type)
// NVFP4: activations are E2M1 packed uint8, so K-dim is hidden/2 bytes
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l1_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode);
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
config.num_padded_sf_pool_tokens, hidden,
@@ -156,18 +156,18 @@ static void sm100_fp8_nvfp4_mega_moe(
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// NVFP4: L1 output is E2M1 packed, intermediate_hidden/2 bytes per row
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_n / 4, config.store_block_m,
intermediate_hidden, config.num_max_pool_tokens,
config.block_n / 2, config.store_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 4);
// NVFP4: L2 activations are E2M1 packed, intermediate_hidden/2 bytes per row
config.swizzle_acts_mode / 2);
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
intermediate_hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode);
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
config.num_padded_sf_pool_tokens, intermediate_hidden,
config.sf_block_m, kGranK,