fix: L1 epilogue uses STSM with XOR swizzle for E2M1 FP4 output

Keep STSM (instead of a naive byte-wise SMEM write) so that TMA reads the expected bank layout.
Pack 4 E2M1 nibbles (2 bytes) into the low half of a uint32 per STSM atom, addressed with the XOR swizzle.
Known perf note: the L1 output uses a 32B swizzle zone (accepted as-is to land for v1).
This commit is contained in:
2026-05-11 20:57:34 +00:00
parent a554de8b24
commit 091b974736

View File

@@ -1059,10 +1059,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
// Quantize to E2M1 FP4 and store into shared memory
// Quantize to E2M1 FP4 and STSM into shared memory
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed
// Scale for FP4: scale = amax / 6.0 (E2M1 max value)
// The UE4M3 scale is computed below (same as the FP8 case, but dividing by 6)
// Pack 4 E2M1 nibbles into the low 2 bytes of a uint32 for one STSM write (2 bytes = 4 FP4 values)
// Keep XOR swizzle for bank-conflict-free SMEM layout that TMA expects
#pragma unroll
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
// Reduce amax
@@ -1071,23 +1071,18 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
// Calculate SF for E2M1: scale = amax / 6.0
// UE4M3 format (same computation as FP8 but different scale base)
// Calculate SF for E2M1: scale = amax / 6.0 (E2M1 max = 6)
float2 sf, sf_inv;
// Use amax/6.0 as the scale (E2M1 max = 6)
sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f);
sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f);
sf_inv.x = 1.0f / sf.x;
sf_inv.y = 1.0f / sf.y;
// E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6]
// Process 4 BF16 values -> 4 E2M1 4-bit values -> pack into 2 bytes
auto quant_e2m1 = [](float v, float scale_inv) -> uint8_t {
float q = v * scale_inv;
q = fmaxf(-6.0f, fminf(6.0f, q));
uint8_t sign = (q < 0.0f) ? 1 : 0;
float aq = fabsf(q);
// Nearest E2M1 index
// E2M1 FP4 quantization helper
auto quant_e2m1 = [](float v) -> uint8_t {
v = fmaxf(-6.0f, fminf(6.0f, v));
uint8_t sign = (v < 0.0f) ? 1 : 0;
float aq = fabsf(v);
uint8_t idx;
if (aq < 0.25f) idx = 0;
else if (aq < 0.75f) idx = 1;
@@ -1097,35 +1092,31 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
else if (aq < 3.5f) idx = 5;
else if (aq < 5.0f) idx = 6;
else idx = 7;
return (sign << 3) | idx; // 4-bit packed
return (sign << 3) | idx;
};
// Quantize 4 BF16 values -> 4 E2M1 nibbles
float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
uint8_t e0 = quant_e2m1(upper.x, 1.0f);
uint8_t e1 = quant_e2m1(upper.y, 1.0f);
uint8_t e2 = quant_e2m1(lower.x, 1.0f);
uint8_t e3 = quant_e2m1(lower.y, 1.0f);
// Pack 2 nibbles per byte: (e1<<4)|e0, (e3<<4)|e2
uint8_t b0 = (e1 << 4) | e0;
uint8_t b1 = (e3 << 4) | e2;
// Quantize 4 BF16 values -> 4 E2M1 nibbles -> pack into 2 bytes
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
uint8_t e0 = quant_e2m1(upper.x);
uint8_t e1 = quant_e2m1(upper.y);
uint8_t e2 = quant_e2m1(lower.x);
uint8_t e3 = quant_e2m1(lower.y);
// Pack 4 nibbles into the low 2 bytes of a uint32: byte 0 = (e1<<4)|e0, byte 1 = (e3<<4)|e2
// Note: only fills 2 of the 4 STSM bytes. The other 2 bytes will be
// written by the adjacent warp (same XOR swizzle column, different row range).
uint32_t packed = (uint32_t)((e1 << 4) | e0)
| ((uint32_t)((e3 << 4) | e2) << 8);
// Store packed FP4 bytes to SMEM (row-major, L1_OUT_BLOCK_N/2 bytes per row)
uint32_t row = lane_idx; // lane maps to row within ATOM_M
// STSM with XOR swizzle (same pattern as FP8, but halved byte stride)
uint32_t row = lane_idx;
uint32_t col = warp_idx_in_wg;
// SMEM layout: simple row-major, L1_OUT_BLOCK_N/2 bytes per row
// Each STSM atom wrote 4 FP8 values = 4 bytes. Now we write 2 bytes (4 FP4 values packed).
// Column offset: each warp handles 4 bytes worth of N (was 4 FP8 = 4 bytes, now 4 FP4 = 2 bytes)
const auto smem_base = smem_cd[tma_stage_idx]
const auto smem_ptr = smem_cd[tma_stage_idx]
+ epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2)
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2);
// Bank-conflict-free addressing: interleave with 4-byte offset per warp
uint32_t byte_col = col * 2; // 2 bytes per warp per row
auto smem_ptr = smem_base + row * (L1_OUT_BLOCK_N / 2) + byte_col;
// Write 2 packed FP4 bytes
smem_ptr[0] = b0;
smem_ptr[1] = b1;
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2)
+ row * (L1_OUT_BLOCK_N / 2)
+ (col ^ (row / 2)) * kNumBankGroupBytes;
ptx::SM100_U32x1_STSM_T<uint32_t>::copy(packed, smem_ptr);
// Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout)
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {