fix: packed FP4 for mxf4nvf4 — correct SMEM layout, UMMA descriptors, L1 epilogue
Key changes: - a_dtype_t/b_dtype_t: float_e2m1_t (packed 4-bit) with sizeof_bits_v==4 assert - kSwizzleAMode/BMode: BLOCK_K/2 (64 bytes packed, not 128 unpacked) - SMEM sizes: LOAD_BLOCK_M * BLOCK_K / 2 (packed byte count) - Token layouts: kHidden/2, kIntermediateHidden/2 (packed bytes) - TMA loads: BLOCK_K/2 inner dim, uint8_t, byte offsets k_block_idx*(BLOCK_K/2) - UMMA descriptors: BLOCK_K/2 template param, uint8_t dtype, UMMA_K/2 advance - L1 epilogue: dropped STSM, direct st.shared.u16 with packed nibbles, no swizzle (v1) - Pybind buffer sizes: hidden/2, intermediate_hidden/2 with packed tensor shapes - Host TMA descriptors: hidden/2 K-dims, block_k/2 inner, fp4_unpacked_smem=false - L1 output TMA: block_n/4 inner, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE)
This commit is contained in:
@@ -134,14 +134,14 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
// NVFP4: kGranK=16 for group_size=16
|
||||
constexpr int kGranK = 16;
|
||||
|
||||
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
|
||||
// (E2M1 packed uint8 is the same format regardless of scale type)
|
||||
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
// Make tensormap — NVFP4 packed E2M1 (2 per byte), so K-dim is hidden/2 bytes
|
||||
// L1 activations: packed E2M1 (4 bits/element), K-dim = hidden/2
|
||||
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
|
||||
hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_k / 2, config.load_block_m,
|
||||
static_cast<int>(l1_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
config.swizzle_acts_mode / 2,
|
||||
0, false, false); // fp4_unpacked_smem=false
|
||||
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
|
||||
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, hidden,
|
||||
@@ -156,18 +156,21 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
intermediate_hidden * 2, hidden,
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1)
|
||||
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_n / 2, config.store_block_m,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_n / 4, config.store_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode / 2);
|
||||
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
0, 0, // no swizzle
|
||||
false, // allow_tf32
|
||||
false); // fp4_unpacked_smem=false (packed!)
|
||||
// L2 activations: packed E2M1, K-dim = intermediate_hidden/2
|
||||
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_k / 2, config.load_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
config.swizzle_acts_mode / 2,
|
||||
0, false, false); // fp4_unpacked_smem=false
|
||||
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, intermediate_hidden,
|
||||
config.sf_block_m, kGranK,
|
||||
|
||||
Reference in New Issue
Block a user