fix: packed FP4 for mxf4nvf4 — correct SMEM layout, UMMA descriptors, L1 epilogue
Key changes:
- a_dtype_t/b_dtype_t: float_e2m1_t (packed 4-bit) with sizeof_bits_v==4 assert
- kSwizzleAMode/BMode: BLOCK_K/2 (64 bytes packed, not 128 unpacked)
- SMEM sizes: LOAD_BLOCK_M * BLOCK_K / 2 (packed byte count)
- Token layouts: kHidden/2, kIntermediateHidden/2 (packed bytes)
- TMA loads: BLOCK_K/2 inner dim, uint8_t, byte offsets k_block_idx*(BLOCK_K/2)
- UMMA descriptors: BLOCK_K/2 template param, uint8_t dtype, UMMA_K/2 advance
- L1 epilogue: dropped STSM, direct st.shared.u16 with packed nibbles, no swizzle (v1)
- Pybind buffer sizes: hidden/2, intermediate_hidden/2 with packed tensor shapes
- Host TMA descriptors: hidden/2 K-dims, block_k/2 inner, fp4_unpacked_smem=false
- L1 output TMA: block_n/4 inner, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE)
This commit is contained in:
@@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
|
||||
// NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes
|
||||
// group_size=16, so SF stride is K/16 (twice as many as MXFP4)
|
||||
const auto fp4_token_layout = layout::Data(hidden);
|
||||
const auto fp4_token_layout = layout::Data(hidden / 2);
|
||||
const auto bf16_token_layout = layout::Data(hidden * 2);
|
||||
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden);
|
||||
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2);
|
||||
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
|
||||
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
|
||||
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
|
||||
@@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token
|
||||
auto x = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
|
||||
{num_max_tokens_per_rank, hidden},
|
||||
{num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
|
||||
auto x_sf = torch::from_blob(
|
||||
@@ -110,10 +110,10 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
|
||||
{num_max_tokens_per_rank, num_topk},
|
||||
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
|
||||
// NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes
|
||||
// NVFP4: L1 output acts are E2M1 packed, hidden/2 bytes
|
||||
auto l1_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
|
||||
{num_max_pool_tokens, hidden},
|
||||
{num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
// NVFP4 L1 SF: M-major, K/64 int32
|
||||
auto l1_acts_sf = torch::from_blob(
|
||||
@@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes
|
||||
auto l2_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
|
||||
{num_max_pool_tokens, intermediate_hidden},
|
||||
{num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
// NVFP4 L2 SF: M-major, K/64 int32
|
||||
auto l2_acts_sf = torch::from_blob(
|
||||
|
||||
@@ -134,14 +134,14 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
// NVFP4: kGranK=16 for group_size=16
|
||||
constexpr int kGranK = 16;
|
||||
|
||||
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
|
||||
// (E2M1 packed uint8 is the same format regardless of scale type)
|
||||
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
// Make tensormap — NVFP4 packed E2M1 (2 per byte), so K-dim is hidden/2 bytes
|
||||
// L1 activations: packed E2M1 (4 bits/element), K-dim = hidden/2
|
||||
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
|
||||
hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_k / 2, config.load_block_m,
|
||||
static_cast<int>(l1_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
config.swizzle_acts_mode / 2,
|
||||
0, false, false); // fp4_unpacked_smem=false
|
||||
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
|
||||
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, hidden,
|
||||
@@ -156,18 +156,21 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
intermediate_hidden * 2, hidden,
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1)
|
||||
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_n / 2, config.store_block_m,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_n / 4, config.store_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode / 2);
|
||||
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
0, 0, // no swizzle
|
||||
false, // allow_tf32
|
||||
false); // fp4_unpacked_smem=false (packed!)
|
||||
// L2 activations: packed E2M1, K-dim = intermediate_hidden/2
|
||||
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_k / 2, config.load_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
config.swizzle_acts_mode / 2,
|
||||
0, false, false); // fp4_unpacked_smem=false
|
||||
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, intermediate_hidden,
|
||||
config.sf_block_m, kGranK,
|
||||
|
||||
Reference in New Issue
Block a user