feat: full FP4 activations for mxf4nvf4 - E2M1 packed A side + UE4M3 scales
mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed).
Changes:
- a_dtype_t: float_e4m3_t → float_e2m1_unpacksmem_t
- UMMA_K: 32 → 64 (FP4 MMA atom)
- L1 epilogue: FP8 quantization → E2M1 FP4 quantization (nearest-neighbor rounding)
- L1 output SMEM: packed E2M1 (2 values per byte), stored via TMA as uint8
- TMA descriptors: adjusted for FP4 packing (K/2 bytes per row)
- SymmBuffer: uint8 activations, shape (M, K//2)
- Staging kernel: BF16 → packed E2M1 + UE4M3 block-16 scales
This commit is contained in:
@@ -95,12 +95,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
|
||||
|
||||
// Token and buffer layouts
|
||||
constexpr auto fp8_token_layout = layout::Data(kHidden);
|
||||
// NVFP4: activations are E2M1 packed (2 per byte), so K/2 bytes per token
|
||||
constexpr auto fp4_token_layout = layout::Data(kHidden / 2);
|
||||
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
|
||||
constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden);
|
||||
// NVFP4 intermediate: same E2M1 packing
|
||||
constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2);
|
||||
// NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4)
|
||||
constexpr auto fp8_sf_layout = layout::Data(kHidden / 16);
|
||||
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
|
||||
constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16);
|
||||
constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
|
||||
constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false);
|
||||
constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false);
|
||||
constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
|
||||
@@ -164,8 +166,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
);
|
||||
|
||||
// Data types
|
||||
// NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1)
|
||||
using a_dtype_t = cutlass::float_e4m3_t;
|
||||
// NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed)
|
||||
using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
|
||||
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
|
||||
|
||||
// MMA configs
|
||||
@@ -173,7 +175,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2;
|
||||
constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB
|
||||
constexpr uint32_t UMMA_K = 32;
|
||||
constexpr uint32_t UMMA_K = 64; // FP4: 64 elements per MMA atom (was 32 for FP8)
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
|
||||
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
|
||||
@@ -181,8 +183,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
||||
|
||||
// Swizzle configs
|
||||
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t);
|
||||
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t);
|
||||
// NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked)
|
||||
// Same byte stride as FP8 — the TMA hardware unpacks 2 E2M1 values per global byte
|
||||
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128
|
||||
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128
|
||||
constexpr uint32_t kSwizzleCDMode = 128;
|
||||
DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N");
|
||||
|
||||
@@ -200,13 +204,16 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
constexpr uint32_t SMEM_EXPERT_COUNT_SIZE =
|
||||
math::constexpr_align<uint32_t>(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment);
|
||||
constexpr uint32_t SMEM_SEND_BUFFER_SIZE =
|
||||
math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
|
||||
math::constexpr_align(fp4_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
|
||||
// NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked)
|
||||
// TMA unpacks 2 E2M1 per global byte into 1 byte per element in SMEM
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
||||
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
||||
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
||||
// L1 output: E2M1 packed FP4 (2 per byte)
|
||||
constexpr uint32_t SMEM_CD_L1_SIZE =
|
||||
kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages;
|
||||
kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_CD_L2_SIZE =
|
||||
kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16);
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
|
||||
@@ -1051,7 +1058,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
||||
|
||||
// Cast to FP8 E4M3 and store into shared memory
|
||||
// Quantize to E2M1 FP4 and store into shared memory
|
||||
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed
|
||||
// Scale for FP4: scale = amax / 6.0 (E2M1 max value)
|
||||
// The UE4M3 scale is computed below (same flow as the FP8 case, but dividing amax by 6)
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
|
||||
// Reduce amax
|
||||
@@ -1060,47 +1070,73 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
|
||||
amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
|
||||
|
||||
// Calculate SF
|
||||
// Calculate SF for E2M1: scale = amax / 6.0
|
||||
// UE4M3 format (same computation as FP8 but different scale base)
|
||||
float2 sf, sf_inv;
|
||||
math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv);
|
||||
// Use amax/6.0 as the scale (E2M1 max = 6)
|
||||
sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f);
|
||||
sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f);
|
||||
sf_inv.x = 1.0f / sf.x;
|
||||
sf_inv.y = 1.0f / sf.y;
|
||||
|
||||
// Cast
|
||||
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
|
||||
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
|
||||
const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y));
|
||||
// E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
// Process 4 BF16 values -> 4 E2M1 4-bit values -> pack into 2 bytes
|
||||
// Quantize `v * scale_inv` to a 4-bit E2M1 code.
//
// Returned layout (low nibble of the byte): bit 3 = sign, bits 2..0 = index into the
// E2M1 magnitude table [0, 0.5, 1, 1.5, 2, 3, 4, 6].
//
// - The scaled value is saturated to [-6, 6] before rounding.
// - A NaN input saturates to +6 (code 7), because `fminf(6, NaN)` yields 6.
// - Rounding is to nearest, with ties going to the code whose mantissa bit is 0
//   (even codes 0/2/4/6). Rounding ties away from zero would bias magnitudes upward.
// - Small negative inputs that round to zero keep their sign bit (code 0x8, i.e. -0).
auto quant_e2m1 = [](float v, float scale_inv) -> uint8_t {
    const float scaled = fmaxf(-6.0f, fminf(6.0f, v * scale_inv));
    const uint8_t sign_bit = (scaled < 0.0f) ? 0x8 : 0x0;
    const float mag = fabsf(scaled);

    // Midpoints between neighbouring E2M1 values: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0.
    // `<=` sends a tie down to the even code below it; `<` sends a tie up to the even code above it.
    uint8_t code;
    if (mag <= 0.25f)      code = 0;  // 0.0 (tie 0.25 -> 0.0)
    else if (mag < 0.75f)  code = 1;  // 0.5
    else if (mag <= 1.25f) code = 2;  // 1.0 (ties 0.75 and 1.25 -> 1.0)
    else if (mag < 1.75f)  code = 3;  // 1.5
    else if (mag <= 2.5f)  code = 4;  // 2.0 (ties 1.75 and 2.5 -> 2.0)
    else if (mag < 3.5f)   code = 5;  // 3.0
    else if (mag <= 5.0f)  code = 6;  // 4.0 (ties 3.5 and 5.0 -> 4.0)
    else                   code = 7;  // 6.0

    return static_cast<uint8_t>(sign_bit | code);  // 4-bit code in the low nibble
};
|
||||
|
||||
// STSM
|
||||
uint32_t row = lane_idx;
|
||||
// Quantize 4 BF16 values -> 4 E2M1 nibbles
|
||||
float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
|
||||
float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
|
||||
uint8_t e0 = quant_e2m1(upper.x, 1.0f);
|
||||
uint8_t e1 = quant_e2m1(upper.y, 1.0f);
|
||||
uint8_t e2 = quant_e2m1(lower.x, 1.0f);
|
||||
uint8_t e3 = quant_e2m1(lower.y, 1.0f);
|
||||
// Pack 2 nibbles per byte: (e1<<4)|e0, (e3<<4)|e2
|
||||
uint8_t b0 = (e1 << 4) | e0;
|
||||
uint8_t b1 = (e3 << 4) | e2;
|
||||
|
||||
// Store packed FP4 bytes to SMEM (row-major, L1_OUT_BLOCK_N/2 bytes per row)
|
||||
uint32_t row = lane_idx; // lane maps to row within ATOM_M
|
||||
uint32_t col = warp_idx_in_wg;
|
||||
const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N
|
||||
+ i * ATOM_M * L1_OUT_BLOCK_N
|
||||
+ row * L1_OUT_BLOCK_N
|
||||
+ (col ^ (row / 2)) * kNumBankGroupBytes;
|
||||
ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr);
|
||||
// SMEM layout: simple row-major, L1_OUT_BLOCK_N/2 bytes per row
|
||||
// Each STSM atom wrote 4 FP8 values = 4 bytes. Now we write 2 bytes (4 FP4 values packed).
|
||||
// Column offset: each warp handles 4 elements of N per row (was 4 FP8 = 4 bytes, now 4 FP4 = 2 bytes)
|
||||
const auto smem_base = smem_cd[tma_stage_idx]
|
||||
+ epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2)
|
||||
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2);
|
||||
// Plain row-major addressing: each warp writes its own 2-byte column slot (no XOR swizzle applied here)
|
||||
uint32_t byte_col = col * 2; // 2 bytes per warp per row
|
||||
auto smem_ptr = smem_base + row * (L1_OUT_BLOCK_N / 2) + byte_col;
|
||||
// Write 2 packed FP4 bytes
|
||||
smem_ptr[0] = b0;
|
||||
smem_ptr[1] = b1;
|
||||
|
||||
// Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout)
|
||||
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
|
||||
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
|
||||
// Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout)
|
||||
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
|
||||
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
|
||||
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
|
||||
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
|
||||
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
|
||||
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
|
||||
// NOTES: originally there was:
|
||||
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2
|
||||
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
|
||||
// We find out that
|
||||
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
|
||||
// 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
|
||||
// This reduce the number of computation instructions.
|
||||
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
|
||||
__builtin_assume(token_base_idx < BLOCK_M);
|
||||
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
|
||||
// NVFP4: convert float scale to UE4M3 format
|
||||
// UE4M3: sign=0 + 4 exp + 3 mantissa, max=448
|
||||
auto to_ue4m3 = [](float v) -> uint8_t {
|
||||
v = fmaxf(0.0f, fminf(v, 448.0f));
|
||||
cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v);
|
||||
@@ -1115,11 +1151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// Issue TMA store after all atoms in this store block
|
||||
if (warp_idx_in_wg == 0 and cute::elect_one_sync()) {
|
||||
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N;
|
||||
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N / 2; // FP4: byte offset = element offset / 2
|
||||
cute::tma_store_fence();
|
||||
cute::SM90_TMA_STORE_2D::copy(
|
||||
&tensor_map_l1_output,
|
||||
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N,
|
||||
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2),
|
||||
out_n_idx,
|
||||
m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M);
|
||||
cute::tma_store_arrive();
|
||||
|
||||
Reference in New Issue
Block a user