fix: packed FP4 for mxf4nvf4 — correct SMEM layout, UMMA descriptors, L1 epilogue
Key changes: - a_dtype_t/b_dtype_t: float_e2m1_t (packed 4-bit) with sizeof_bits_v==4 assert - kSwizzleAMode/BMode: BLOCK_K/2 (64 bytes packed, not 128 unpacked) - SMEM sizes: LOAD_BLOCK_M * BLOCK_K / 2 (packed byte count) - Token layouts: kHidden/2, kIntermediateHidden/2 (packed bytes) - TMA loads: BLOCK_K/2 inner dim, uint8_t, byte offsets k_block_idx*(BLOCK_K/2) - UMMA descriptors: BLOCK_K/2 template param, uint8_t dtype, UMMA_K/2 advance - L1 epilogue: dropped STSM, direct st.shared.u16 with packed nibbles, no swizzle (v1) - Pybind buffer sizes: hidden/2, intermediate_hidden/2 with packed tensor shapes - Host TMA descriptors: hidden/2 K-dims, block_k/2 inner, fp4_unpacked_smem=false - L1 output TMA: block_n/4 inner, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE)
This commit is contained in:
@@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
|
||||
// NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes
|
||||
// group_size=16, so SF stride is K/16 (twice as many as MXFP4)
|
||||
const auto fp4_token_layout = layout::Data(hidden);
|
||||
const auto fp4_token_layout = layout::Data(hidden / 2);
|
||||
const auto bf16_token_layout = layout::Data(hidden * 2);
|
||||
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden);
|
||||
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2);
|
||||
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
|
||||
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
|
||||
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
|
||||
@@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token
|
||||
auto x = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
|
||||
{num_max_tokens_per_rank, hidden},
|
||||
{num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
|
||||
auto x_sf = torch::from_blob(
|
||||
@@ -110,10 +110,10 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
|
||||
{num_max_tokens_per_rank, num_topk},
|
||||
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
|
||||
// NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes
|
||||
// NVFP4: L1 output acts are E2M1 packed, hidden/2 bytes
|
||||
auto l1_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
|
||||
{num_max_pool_tokens, hidden},
|
||||
{num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
// NVFP4 L1 SF: M-major, K/64 int32
|
||||
auto l1_acts_sf = torch::from_blob(
|
||||
@@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
// NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes
|
||||
auto l2_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
|
||||
{num_max_pool_tokens, intermediate_hidden},
|
||||
{num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
// NVFP4 L2 SF: M-major, K/64 int32
|
||||
auto l2_acts_sf = torch::from_blob(
|
||||
|
||||
@@ -134,14 +134,14 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
// NVFP4: kGranK=16 for group_size=16
|
||||
constexpr int kGranK = 16;
|
||||
|
||||
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
|
||||
// (E2M1 packed uint8 is the same format regardless of scale type)
|
||||
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
// Make tensormap — NVFP4 packed E2M1 (2 per byte), so K-dim is hidden/2 bytes
|
||||
// L1 activations: packed E2M1 (4 bits/element), K-dim = hidden/2
|
||||
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
|
||||
hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_k / 2, config.load_block_m,
|
||||
static_cast<int>(l1_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
config.swizzle_acts_mode / 2,
|
||||
0, false, false); // fp4_unpacked_smem=false
|
||||
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
|
||||
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, hidden,
|
||||
@@ -156,18 +156,21 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
intermediate_hidden * 2, hidden,
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1)
|
||||
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_n / 2, config.store_block_m,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_n / 4, config.store_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode / 2);
|
||||
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
|
||||
0, 0, // no swizzle
|
||||
false, // allow_tf32
|
||||
false); // fp4_unpacked_smem=false (packed!)
|
||||
// L2 activations: packed E2M1, K-dim = intermediate_hidden/2
|
||||
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_k / 2, config.load_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
config.swizzle_acts_mode / 2,
|
||||
0, false, false); // fp4_unpacked_smem=false
|
||||
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, intermediate_hidden,
|
||||
config.sf_block_m, kGranK,
|
||||
|
||||
@@ -96,10 +96,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Token and buffer layouts
|
||||
// NVFP4: activations are E2M1 packed (2 per byte), so K/2 bytes per token
|
||||
constexpr auto fp4_token_layout = layout::Data(kHidden);
|
||||
constexpr auto fp4_token_layout = layout::Data(kHidden / 2); // packed: 2 E2M1 per byte
|
||||
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
|
||||
// NVFP4 intermediate: same E2M1 packing
|
||||
constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden);
|
||||
constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2); // packed
|
||||
// NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4)
|
||||
constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16);
|
||||
constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
|
||||
@@ -167,13 +167,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Data types
|
||||
// NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed)
|
||||
using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
|
||||
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
|
||||
// Verify SMEM element sizes: unpacksmem = 1 byte/element (not 0.5)
|
||||
// The TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes.
|
||||
// UMMA reads unpacked SMEM and packs internally for mxf4nvf4.
|
||||
static_assert(sizeof(a_dtype_t) == 1, "a_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element");
|
||||
static_assert(sizeof(b_dtype_t) == 1, "b_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element");
|
||||
using a_dtype_t = cutlass::float_e2m1_t;
|
||||
using b_dtype_t = cutlass::float_e2m1_t;
|
||||
// mxf4nvf4 reads packed FP4 from SMEM (2 values per byte), NOT unpacked.
|
||||
// _unpacksmem_t was for mxf8f6f4 which reads FP4 as FP8 (1 byte/element).
|
||||
// For mxf4nvf4: sizeof_bits = 4, SMEM stride = BLOCK_K/2 bytes, UMMA_K = 64.
|
||||
static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
|
||||
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
|
||||
static_assert(cutlass::sizeof_bits_v<b_dtype_t> == 4,
|
||||
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
|
||||
|
||||
// MMA configs
|
||||
// NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major
|
||||
@@ -190,8 +192,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// Swizzle configs
|
||||
// NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked)
|
||||
// Same byte stride as FP8 — the TMA hardware unpacks 2 E2M1 values per global byte
|
||||
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128
|
||||
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128
|
||||
// Packed FP4: BLOCK_K elements = BLOCK_K/2 bytes per row (2 values per byte)
|
||||
constexpr uint32_t kSwizzleAMode = BLOCK_K / 2; // 64 bytes
|
||||
constexpr uint32_t kSwizzleBMode = BLOCK_K / 2; // 64 bytes
|
||||
constexpr uint32_t kSwizzleCDMode = 128;
|
||||
DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N");
|
||||
|
||||
@@ -212,15 +215,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
math::constexpr_align(fp4_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
|
||||
// NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked)
|
||||
// TMA unpacks 2 E2M1 per global byte into 1 byte per element in SMEM
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
||||
// Packed FP4: 4 bits/element → LOAD_BLOCK_M * BLOCK_K / 2 bytes
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2;
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2;
|
||||
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
||||
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
||||
// L1 output: unpacked E2M1 in SMEM (1 byte/element, same as FP8 layout)
|
||||
// TMA store writes to global as unpacked; L2 TMA load reads unpacked into SMEM
|
||||
// This avoids the packed/unpacked mismatch at the L1→L2 boundary
|
||||
// L1 output: packed E2M1 in SMEM (2 per byte), L1_OUT_BLOCK_N/2 bytes per row
|
||||
constexpr uint32_t SMEM_CD_L1_SIZE =
|
||||
kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(uint8_t) * kNumTMAStoreStages;
|
||||
kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_CD_L2_SIZE =
|
||||
kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16);
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
|
||||
@@ -573,7 +575,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
pull_buffer.get_base_ptr(),
|
||||
sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(),
|
||||
current_rank_in_expert_idx),
|
||||
pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8
|
||||
pull_mbarrier, kHidden / 2); // NVFP4: packed E2M1, half the bytes
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -605,7 +607,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
*l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr<float>() = weight;
|
||||
|
||||
// Wait for TMA token load to complete
|
||||
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8
|
||||
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: packed E2M1, half the bytes
|
||||
ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
|
||||
|
||||
// Store token to local L1 buffer via TMA
|
||||
@@ -728,7 +730,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Compute token offset from pool block index
|
||||
uint32_t m_idx = pool_block_idx * BLOCK_M;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset
|
||||
uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M;
|
||||
uint32_t sfa_k_idx = k_block_idx;
|
||||
|
||||
@@ -738,7 +740,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// TMA copy tokens and SFA, then arrive at full barrier
|
||||
if (cute::elect_one_sync()) {
|
||||
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(
|
||||
tma::copy<BLOCK_K / 2, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(
|
||||
tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2);
|
||||
tma::copy<SF_BLOCK_M, 1, 0>(
|
||||
tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
|
||||
@@ -775,13 +777,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Compute weight offset
|
||||
uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset
|
||||
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
|
||||
uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx;
|
||||
|
||||
// TMA copy weights with SF
|
||||
if (cute::elect_one_sync()) {
|
||||
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(
|
||||
// NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes
|
||||
tma::copy<BLOCK_K / 2, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(
|
||||
tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2);
|
||||
tma::copy<BLOCK_N, 1, 0>(
|
||||
tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
|
||||
@@ -810,8 +813,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
|
||||
|
||||
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
||||
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
||||
// NVFP4: UMMA descriptors use packed byte dimensions (BLOCK_K/2, uint8_t)
|
||||
// because sizeof(float_e2m1_t)=1 but real stride is BLOCK_K/2 bytes per K-row
|
||||
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K / 2, kSwizzleAMode, false, uint8_t>(smem_a[0], 0, 0);
|
||||
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K / 2, kSwizzleBMode, false, uint8_t>(smem_b[0], 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
@@ -884,9 +889,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
const auto runtime_instr_desc =
|
||||
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k);
|
||||
a_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
b_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
|
||||
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
|
||||
b_desc, a_desc, accum_stage_idx * UMMA_N,
|
||||
@@ -1066,65 +1071,66 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
||||
|
||||
// Quantize to E2M1 FP4 and STSM into shared memory (unpacked, 1 byte/element)
|
||||
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1
|
||||
// SMEM layout: unpacked (1 byte per E2M1 element), same byte count as FP8
|
||||
// 4 BF16 → 4 E2M1 bytes → 1 STSM uint32 write, with XOR swizzle
|
||||
// Quantize to E2M1 FP4 and write packed SMEM (2 values per byte)
|
||||
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is packed E2M1
|
||||
// SMEM layout: packed, L1_OUT_BLOCK_N/2 bytes per row, no swizzle (v1)
|
||||
auto quant_e2m1_byte = [](float v) -> uint8_t {
|
||||
v = fmaxf(-6.0f, fminf(6.0f, v));
|
||||
uint8_t sign = (v < 0.0f) ? 1 : 0;
|
||||
float aq = fabsf(v);
|
||||
uint8_t idx;
|
||||
if (aq < 0.25f) idx = 0;
|
||||
else if (aq < 0.75f) idx = 1;
|
||||
else if (aq < 1.25f) idx = 2;
|
||||
else if (aq < 1.75f) idx = 3;
|
||||
else if (aq < 2.5f) idx = 4;
|
||||
else if (aq < 3.5f) idx = 5;
|
||||
else if (aq < 5.0f) idx = 6;
|
||||
else idx = 7;
|
||||
return uint8_t((sign << 3) | idx);
|
||||
};
|
||||
|
||||
// Lane mapping: row_in_atom = lane_idx/4, col_pair = lane_idx%4
|
||||
// Each warp owns L1_OUT_BLOCK_N/8 bytes per row (= 16 N-elements)
|
||||
constexpr uint32_t kWarpBytesPerRow = L1_OUT_BLOCK_N / 8;
|
||||
const uint32_t row_in_atom = lane_idx >> 2;
|
||||
const uint32_t col_pair = lane_idx & 3;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
|
||||
// Reduce amax
|
||||
// Reduce amax across warp pair
|
||||
const float2 wp_amax =
|
||||
smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4];
|
||||
amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
|
||||
amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
|
||||
|
||||
// Calculate SF for E2M1: scale = amax / 6.0 (E2M1 max = 6)
|
||||
// E2M1: scale = amax / 6.0 (E2M1 max magnitude)
|
||||
float2 sf, sf_inv;
|
||||
sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f);
|
||||
sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f);
|
||||
sf_inv.x = 1.0f / sf.x;
|
||||
sf_inv.y = 1.0f / sf.y;
|
||||
|
||||
// E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
// Store as 4-bit in low nibble of each byte (unpacked SMEM: 1 byte/element)
|
||||
auto quant_e2m1_byte = [](float v) -> uint8_t {
|
||||
v = fmaxf(-6.0f, fminf(6.0f, v));
|
||||
uint8_t sign = (v < 0.0f) ? 1 : 0;
|
||||
float aq = fabsf(v);
|
||||
uint8_t idx;
|
||||
if (aq < 0.25f) idx = 0;
|
||||
else if (aq < 0.75f) idx = 1;
|
||||
else if (aq < 1.25f) idx = 2;
|
||||
else if (aq < 1.75f) idx = 3;
|
||||
else if (aq < 2.5f) idx = 4;
|
||||
else if (aq < 3.5f) idx = 5;
|
||||
else if (aq < 5.0f) idx = 6;
|
||||
else idx = 7;
|
||||
return (sign << 3) | idx; // 4-bit E2M1 in low nibble
|
||||
};
|
||||
|
||||
// Quantize 4 BF16 values → 4 E2M1 bytes → pack into uint32 for STSM
|
||||
// 4 SwiGLU floats → 4 E2M1 nibbles → 2 bytes (uint16) for this lane
|
||||
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
|
||||
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
|
||||
uint8_t e0 = quant_e2m1_byte(upper.x);
|
||||
uint8_t e1 = quant_e2m1_byte(upper.y);
|
||||
uint8_t e2 = quant_e2m1_byte(lower.x);
|
||||
uint8_t e3 = quant_e2m1_byte(lower.y);
|
||||
// Pack 4 bytes into uint32 for one STSM atom
|
||||
uint32_t packed = (uint32_t)e0 | ((uint32_t)e1 << 8)
|
||||
| ((uint32_t)e2 << 16) | ((uint32_t)e3 << 24);
|
||||
const uint8_t e0 = quant_e2m1_byte(upper.x);
|
||||
const uint8_t e1 = quant_e2m1_byte(upper.y);
|
||||
const uint8_t e2 = quant_e2m1_byte(lower.x);
|
||||
const uint8_t e3 = quant_e2m1_byte(lower.y);
|
||||
const uint16_t packed16 =
|
||||
uint16_t(((e1 & 0xF) << 4) | (e0 & 0xF))
|
||||
| uint16_t((((e3 & 0xF) << 4) | (e2 & 0xF)) << 8);
|
||||
|
||||
// STSM with XOR swizzle (same layout as FP8, 1 byte/element)
|
||||
uint32_t row = lane_idx;
|
||||
uint32_t col = warp_idx_in_wg;
|
||||
const auto smem_ptr = smem_cd[tma_stage_idx]
|
||||
+ epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N
|
||||
+ i * ATOM_M * L1_OUT_BLOCK_N
|
||||
+ row * L1_OUT_BLOCK_N
|
||||
+ (col ^ (row / 2)) * kNumBankGroupBytes;
|
||||
ptx::SM100_U32x1_STSM_T<uint32_t>::copy(packed, smem_ptr);
|
||||
// Direct shared store (st.shared.u16) — no STSM, no swizzle
|
||||
uint8_t* smem_byte_ptr = smem_cd[tma_stage_idx]
|
||||
+ epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2)
|
||||
+ (i * ATOM_M + row_in_atom) * (L1_OUT_BLOCK_N / 2)
|
||||
+ warp_idx_in_wg * kWarpBytesPerRow
|
||||
+ col_pair * 2;
|
||||
*reinterpret_cast<uint16_t*>(smem_byte_ptr) = packed16;
|
||||
|
||||
// Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout)
|
||||
// SF store to l2_sf_buffer as UE4M3 (MN-major layout)
|
||||
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
|
||||
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
|
||||
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
|
||||
@@ -1134,26 +1140,26 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
__builtin_assume(token_base_idx < BLOCK_M);
|
||||
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * uint32_t(sizeof(uint32_t)) + byte_idx;
|
||||
auto to_ue4m3 = [](float v) -> uint8_t {
|
||||
v = fmaxf(0.0f, fminf(v, 448.0f));
|
||||
cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v);
|
||||
return reinterpret_cast<uint8_t&>(e4m3_val) & 0x7F;
|
||||
cutlass::float_e4m3_t e = cutlass::float_e4m3_t(v);
|
||||
return reinterpret_cast<uint8_t&>(e) & 0x7F;
|
||||
};
|
||||
sf_base_ptr[sf_addr] = to_ue4m3(sf.x);
|
||||
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] = to_ue4m3(sf.y);
|
||||
sf_base_ptr[sf_addr + 4 * uint32_t(sizeof(uint32_t))] = to_ue4m3(sf.y);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
||||
|
||||
// Issue TMA store after all atoms in this store block
|
||||
// TMA store SMEM → global (packed FP4 byte offsets)
|
||||
if (warp_idx_in_wg == 0 and cute::elect_one_sync()) {
|
||||
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; // unpacked: 1 byte/element
|
||||
const uint32_t out_n_idx = n_block_idx * (L1_OUT_BLOCK_N / 2);
|
||||
cute::tma_store_fence();
|
||||
cute::SM90_TMA_STORE_2D::copy(
|
||||
&tensor_map_l1_output,
|
||||
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N),
|
||||
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2),
|
||||
out_n_idx,
|
||||
m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M);
|
||||
cute::tma_store_arrive();
|
||||
|
||||
Reference in New Issue
Block a user