diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp index d47bfb0..fd9fff0 100644 --- a/csrc/apis/mega_nvfp4.hpp +++ b/csrc/apis/mega_nvfp4.hpp @@ -26,10 +26,11 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // Workspace bytes const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); - // NVFP4 layouts: group_size=16, so SF stride is K/16 (twice as many as MXFP4) - const auto fp8_token_layout = layout::Data(hidden); + // NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes + // group_size=16, so SF stride is K/16 (twice as many as MXFP4) + const auto fp4_token_layout = layout::Data(hidden / 2); const auto bf16_token_layout = layout::Data(hidden * 2); - const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); + const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2); const auto nvfp4_sf_layout = layout::Data(hidden / 16); const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); @@ -38,7 +39,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // Input buffers const auto input_token_buffer = layout::Buffer( - fp8_token_layout, 1, num_max_tokens_per_rank, + fp4_token_layout, 1, num_max_tokens_per_rank, workspace.get_end_ptr()); const auto input_sf_buffer = layout::Buffer( nvfp4_sf_layout, 1, num_max_tokens_per_rank, @@ -62,7 +63,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // L1 input buffer const auto l1_token_buffer = layout::Buffer( - fp8_token_layout, 1, num_max_pool_tokens, + fp4_token_layout, 1, num_max_pool_tokens, input_topk_weights_buffer.get_end_ptr()); const auto l1_sf_buffer = layout::Buffer( nvfp4_sf_layout, 1, num_max_padded_sf_pool_tokens, @@ -73,7 +74,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // L2 input buffer const auto l2_token_buffer = layout::Buffer( - fp8_intermediate_token_layout, 1, num_max_pool_tokens, + fp4_intermediate_token_layout, 1, num_max_pool_tokens, l1_topk_weights_buffer.get_end_ptr()); const auto l2_sf_buffer = layout::Buffer( nvfp4_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, @@ -91,10 +92,11 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // Slice function auto slice_input_buffers = [=](const torch::Tensor& buffer) { + // NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token auto x = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), - {num_max_tokens_per_rank, hidden}, - torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + {num_max_tokens_per_rank, hidden / 2}, + torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 auto x_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), @@ -108,20 +110,22 @@ get_symm_buffer_size_for_nvfp4_mega_moe( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)), {num_max_tokens_per_rank, num_topk}, torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); + // NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes auto l1_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), - {num_max_pool_tokens, hidden}, - torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + {num_max_pool_tokens, hidden / 2}, + torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 L1 SF: M-major, K/64 int32 auto l1_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), {num_max_padded_sf_pool_tokens, hidden / 64}, {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + // NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), - {num_max_pool_tokens, intermediate_hidden}, - torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + {num_max_pool_tokens, intermediate_hidden / 2}, + torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 L2 SF: M-major, K/64 int32 auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index cf91d35..ecdbabf 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -136,11 +136,12 @@ static void sm100_fp8_nvfp4_mega_moe( // Make tensormap — weight/activation TMA descriptors are the same as MXFP4 // (E2M1 packed uint8 is the same format regardless of scale type) + // NVFP4: activations are E2M1 packed uint8, so K-dim is hidden/2 bytes const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, - hidden, config.num_max_pool_tokens, - config.block_k, config.load_block_m, + hidden / 2, config.num_max_pool_tokens, + config.block_k / 2, config.load_block_m, static_cast(l1_acts.stride(-2)), - config.swizzle_acts_mode); + config.swizzle_acts_mode / 2); // NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32 const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, config.num_padded_sf_pool_tokens, hidden, @@ -155,16 +156,18 @@ static void sm100_fp8_nvfp4_mega_moe( intermediate_hidden * 2, hidden, config.block_n, kGranK, num_experts_per_rank, 0); + // NVFP4: L1 output is E2M1 packed, intermediate_hidden/2 bytes per row const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, - intermediate_hidden, config.num_max_pool_tokens, - config.block_n / 2, config.store_block_m, + intermediate_hidden / 2, config.num_max_pool_tokens, + config.block_n / 4, config.store_block_m, static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode / 2); + config.swizzle_acts_mode / 4); + // NVFP4: L2 activations are E2M1 packed, intermediate_hidden/2 bytes per row const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, - intermediate_hidden, config.num_max_pool_tokens, - config.block_k, config.load_block_m, + intermediate_hidden / 2, config.num_max_pool_tokens, + config.block_k / 2, config.load_block_m, static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode); + config.swizzle_acts_mode / 2); const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, config.num_padded_sf_pool_tokens, intermediate_hidden, config.sf_block_m, kGranK, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index a2170e5..321d4e1 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -95,12 +95,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); // Token and buffer layouts - constexpr auto fp8_token_layout = layout::Data(kHidden); + // NVFP4: activations are E2M1 packed (2 per byte), so K/2 bytes per token + constexpr auto fp4_token_layout = layout::Data(kHidden / 2); constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); - constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + // NVFP4 intermediate: same E2M1 packing + constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2); // NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4) - constexpr auto fp8_sf_layout = layout::Data(kHidden / 16); - constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); + constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16); + constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -164,8 +166,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, ); // Data types - // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) - using a_dtype_t = cutlass::float_e4m3_t; + // NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed) + using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; // MMA configs @@ -173,7 +175,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, constexpr uint32_t LAYOUT_AD_M = 128; constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB - constexpr uint32_t UMMA_K = 32; + constexpr uint32_t UMMA_K = 64; // FP4: 64 elements per MMA atom (was 32 for FP8) constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); @@ -181,8 +183,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); // Swizzle configs - constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); - constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + // NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked) + // Same byte stride as FP8 — the TMA hardware unpacks 2 E2M1 values per global byte + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128 + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128 constexpr uint32_t kSwizzleCDMode = 128; DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); @@ -200,13 +204,16 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); constexpr uint32_t SMEM_SEND_BUFFER_SIZE = - math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + math::constexpr_align(fp4_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + // NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked) + // TMA unpacks 2 E2M1 per global byte into 1 byte per element in SMEM constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + // L1 output: E2M1 packed FP4 (2 per byte) constexpr uint32_t SMEM_CD_L1_SIZE = - kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages; constexpr uint32_t SMEM_CD_L2_SIZE = kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; @@ -1051,7 +1058,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, ptx::tma_store_wait(); ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - // Cast to FP8 E4M3 and store into shared memory + // Quantize to E2M1 FP4 and store into shared memory + // NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed + // Scale for FP4: scale = amax / 6.0 (E2M1 max value) + // UE4M3 scale already computed below (same as FP8 case but using /6) #pragma unroll for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { // Reduce amax @@ -1060,47 +1070,73 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); - // Calculate SF + // Calculate SF for E2M1: scale = amax / 6.0 + // UE4M3 format (same computation as FP8 but different scale base) float2 sf, sf_inv; - math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + // Use amax/6.0 as the scale (E2M1 max = 6) + sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f); + sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f); + sf_inv.x = 1.0f / sf.x; + sf_inv.y = 1.0f / sf.y; - // Cast - const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); - const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); - const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + // E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6] + // Process 4 BF16 values -> 4 E2M1 4-bit values -> pack into 2 bytes + auto quant_e2m1 = [](float v, float scale_inv) -> uint8_t { + float q = v * scale_inv; + q = fmaxf(-6.0f, fminf(6.0f, q)); + uint8_t sign = (q < 0.0f) ? 1 : 0; + float aq = fabsf(q); + // Nearest E2M1 index + uint8_t idx; + if (aq < 0.25f) idx = 0; + else if (aq < 0.75f) idx = 1; + else if (aq < 1.25f) idx = 2; + else if (aq < 1.75f) idx = 3; + else if (aq < 2.5f) idx = 4; + else if (aq < 3.5f) idx = 5; + else if (aq < 5.0f) idx = 6; + else idx = 7; + return (sign << 3) | idx; // 4-bit packed + }; - // STSM - uint32_t row = lane_idx; + // Quantize 4 BF16 values -> 4 E2M1 nibbles + float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + uint8_t e0 = quant_e2m1(upper.x, 1.0f); + uint8_t e1 = quant_e2m1(upper.y, 1.0f); + uint8_t e2 = quant_e2m1(lower.x, 1.0f); + uint8_t e3 = quant_e2m1(lower.y, 1.0f); + // Pack 2 nibbles per byte: (e1<<4)|e0, (e3<<4)|e2 + uint8_t b0 = (e1 << 4) | e0; + uint8_t b1 = (e3 << 4) | e2; + + // Store packed FP4 bytes to SMEM (row-major, L1_OUT_BLOCK_N/2 bytes per row) + uint32_t row = lane_idx; // lane maps to row within ATOM_M uint32_t col = warp_idx_in_wg; - const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N - + i * ATOM_M * L1_OUT_BLOCK_N - + row * L1_OUT_BLOCK_N - + (col ^ (row / 2)) * kNumBankGroupBytes; - ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + // SMEM layout: simple row-major, L1_OUT_BLOCK_N/2 bytes per row + // Each STSM atom wrote 4 FP8 values = 4 bytes. Now we write 2 bytes (4 FP4 values packed). + // Column offset: each warp handles 4 bytes worth of N (was 4 FP8 = 4 bytes, now 4 FP4 = 2 bytes) + const auto smem_base = smem_cd[tma_stage_idx] + + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) + + i * ATOM_M * (L1_OUT_BLOCK_N / 2); + // Bank-conflict-free addressing: interleave with 4-byte offset per warp + uint32_t byte_col = col * 2; // 2 bytes per warp per row + auto smem_ptr = smem_base + row * (L1_OUT_BLOCK_N / 2) + byte_col; + // Write 2 packed FP4 bytes + smem_ptr[0] = b0; + smem_ptr[1] = b1; - // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) - // Only one warp per pair writes (both hold the same SF after cross-warp reduce) - // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + // Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout) if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); - // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 - // NOTES: originally there was: - // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 - // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` - // We find out that - // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside - // 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside - // This reduce the number of computation instructions. const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; __builtin_assume(token_base_idx < BLOCK_M); const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; - // NVFP4: convert float scale to UE4M3 format - // UE4M3: sign=0 + 4 exp + 3 mantissa, max=448 auto to_ue4m3 = [](float v) -> uint8_t { v = fmaxf(0.0f, fminf(v, 448.0f)); cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v); @@ -1115,11 +1151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Issue TMA store after all atoms in this store block if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { - uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N / 2; // FP4: byte offset = element offset / 2 cute::tma_store_fence(); cute::SM90_TMA_STORE_2D::copy( &tensor_map_l1_output, - smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2), out_n_idx, m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); cute::tma_store_arrive(); diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 7808432..c4940c6 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -312,7 +312,9 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X with UE4M3 block scales (group_size=16). + Both activations AND weights are E2M1 packed (FP4×FP4). Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales) + Activation format: E2M1 packed uint8 + UE4M3 scales (computed by staging kernel) Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16. """ _C.fp8_nvfp4_mega_moe(