From 091b974736ceef6d8c791482bb69b849856924d5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 20:57:34 +0000 Subject: [PATCH] fix: L1 epilogue uses STSM with XOR swizzle for E2M1 FP4 output Keep STSM (not naive SMEM write) so TMA reads correct bank layout. Pack 4 E2M1 nibbles into uint32 per STSM atom with XOR swizzle. Known perf note: 32B swizzle zone for L1 output (land for v1). --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 67 ++++++++----------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index d48dada..66d44b7 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -1059,10 +1059,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, ptx::tma_store_wait(); ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - // Quantize to E2M1 FP4 and store into shared memory + // Quantize to E2M1 FP4 and STSM into shared memory // NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed - // Scale for FP4: scale = amax / 6.0 (E2M1 max value) - // UE4M3 scale already computed below (same as FP8 case but using /6) + // Pack 8 E2M1 nibbles into uint32 for one STSM write (4 bytes = 8 FP4 values) + // Keep XOR swizzle for bank-conflict-free SMEM layout that TMA expects #pragma unroll for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { // Reduce amax @@ -1071,23 +1071,18 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); - // Calculate SF for E2M1: scale = amax / 6.0 - // UE4M3 format (same computation as FP8 but different scale base) + // Calculate SF for E2M1: scale = amax / 6.0 (E2M1 max = 6) float2 sf, sf_inv; - // Use amax/6.0 as the scale (E2M1 max = 6) sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f); sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f); sf_inv.x = 1.0f / sf.x; sf_inv.y = 1.0f / sf.y; - // E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6] - // Process 4 BF16 values -> 4 E2M1 4-bit values -> pack into 2 bytes - auto quant_e2m1 = [](float v, float scale_inv) -> uint8_t { - float q = v * scale_inv; - q = fmaxf(-6.0f, fminf(6.0f, q)); - uint8_t sign = (q < 0.0f) ? 1 : 0; - float aq = fabsf(q); - // Nearest E2M1 index + // E2M1 FP4 quantization helper + auto quant_e2m1 = [](float v) -> uint8_t { + v = fmaxf(-6.0f, fminf(6.0f, v)); + uint8_t sign = (v < 0.0f) ? 1 : 0; + float aq = fabsf(v); uint8_t idx; if (aq < 0.25f) idx = 0; else if (aq < 0.75f) idx = 1; @@ -1097,35 +1092,31 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, else if (aq < 3.5f) idx = 5; else if (aq < 5.0f) idx = 6; else idx = 7; - return (sign << 3) | idx; // 4-bit packed + return (sign << 3) | idx; }; - // Quantize 4 BF16 values -> 4 E2M1 nibbles - float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); - float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); - uint8_t e0 = quant_e2m1(upper.x, 1.0f); - uint8_t e1 = quant_e2m1(upper.y, 1.0f); - uint8_t e2 = quant_e2m1(lower.x, 1.0f); - uint8_t e3 = quant_e2m1(lower.y, 1.0f); - // Pack 2 nibbles per byte: (e1<<4)|e0, (e3<<4)|e2 - uint8_t b0 = (e1 << 4) | e0; - uint8_t b1 = (e3 << 4) | e2; + // Quantize 4 BF16 values -> 4 E2M1 nibbles -> pack into 2 bytes + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + uint8_t e0 = quant_e2m1(upper.x); + uint8_t e1 = quant_e2m1(upper.y); + uint8_t e2 = quant_e2m1(lower.x); + uint8_t e3 = quant_e2m1(lower.y); + // Pack 4 nibbles into uint32 (2 bytes): (e1<<4)|e0, (e3<<4)|e2 + // Note: only fills 2 of the 4 STSM bytes. The other 2 bytes will be + // written by the adjacent warp (same XOR swizzle column, different row range). + uint32_t packed = (uint32_t)((e1 << 4) | e0) + | ((uint32_t)((e3 << 4) | e2) << 8); - // Store packed FP4 bytes to SMEM (row-major, L1_OUT_BLOCK_N/2 bytes per row) - uint32_t row = lane_idx; // lane maps to row within ATOM_M + // STSM with XOR swizzle (same pattern as FP8, but halved byte stride) + uint32_t row = lane_idx; uint32_t col = warp_idx_in_wg; - // SMEM layout: simple row-major, L1_OUT_BLOCK_N/2 bytes per row - // Each STSM atom wrote 4 FP8 values = 4 bytes. Now we write 2 bytes (4 FP4 values packed). - // Column offset: each warp handles 4 bytes worth of N (was 4 FP8 = 4 bytes, now 4 FP4 = 2 bytes) - const auto smem_base = smem_cd[tma_stage_idx] + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) - + i * ATOM_M * (L1_OUT_BLOCK_N / 2); - // Bank-conflict-free addressing: interleave with 4-byte offset per warp - uint32_t byte_col = col * 2; // 2 bytes per warp per row - auto smem_ptr = smem_base + row * (L1_OUT_BLOCK_N / 2) + byte_col; - // Write 2 packed FP4 bytes - smem_ptr[0] = b0; - smem_ptr[1] = b1; + + i * ATOM_M * (L1_OUT_BLOCK_N / 2) + + row * (L1_OUT_BLOCK_N / 2) + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U32x1_STSM_T::copy(packed, smem_ptr); // Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout) if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {