fix: L1 epilogue uses STSM with XOR swizzle for E2M1 FP4 output
Keep STSM (instead of a naive SMEM write) so that TMA reads the bank layout it expects. Pack 4 E2M1 nibbles into one uint32 per STSM atom, with the XOR swizzle applied to the store address. Known perf note: the L1 output uses a 32B swizzle zone (landing as-is for v1).
This commit is contained in:
@@ -1059,10 +1059,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
||||
|
||||
// Quantize to E2M1 FP4 and store into shared memory
|
||||
// Quantize to E2M1 FP4 and STSM into shared memory
|
||||
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed
|
||||
// Scale for FP4: scale = amax / 6.0 (E2M1 max value)
|
||||
// UE4M3 scale already computed below (same as FP8 case but using /6)
|
||||
// Pack 4 E2M1 nibbles into the low 2 bytes of the uint32 passed to one STSM write (2 bytes = 4 FP4 values)
|
||||
// Keep XOR swizzle for bank-conflict-free SMEM layout that TMA expects
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
|
||||
// Reduce amax
|
||||
@@ -1071,23 +1071,18 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
|
||||
amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
|
||||
|
||||
// Calculate SF for E2M1: scale = amax / 6.0
|
||||
// UE4M3 format (same computation as FP8 but different scale base)
|
||||
// Calculate SF for E2M1: scale = amax / 6.0 (E2M1 max = 6)
|
||||
float2 sf, sf_inv;
|
||||
// Use amax/6.0 as the scale (E2M1 max = 6)
|
||||
sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f);
|
||||
sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f);
|
||||
sf_inv.x = 1.0f / sf.x;
|
||||
sf_inv.y = 1.0f / sf.y;
|
||||
|
||||
// E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
// Process 4 BF16 values -> 4 E2M1 4-bit values -> pack into 2 bytes
|
||||
auto quant_e2m1 = [](float v, float scale_inv) -> uint8_t {
|
||||
float q = v * scale_inv;
|
||||
q = fmaxf(-6.0f, fminf(6.0f, q));
|
||||
uint8_t sign = (q < 0.0f) ? 1 : 0;
|
||||
float aq = fabsf(q);
|
||||
// Nearest E2M1 index
|
||||
// E2M1 FP4 quantization helper
|
||||
auto quant_e2m1 = [](float v) -> uint8_t {
|
||||
v = fmaxf(-6.0f, fminf(6.0f, v));
|
||||
uint8_t sign = (v < 0.0f) ? 1 : 0;
|
||||
float aq = fabsf(v);
|
||||
uint8_t idx;
|
||||
if (aq < 0.25f) idx = 0;
|
||||
else if (aq < 0.75f) idx = 1;
|
||||
@@ -1097,35 +1092,31 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
else if (aq < 3.5f) idx = 5;
|
||||
else if (aq < 5.0f) idx = 6;
|
||||
else idx = 7;
|
||||
return (sign << 3) | idx; // 4-bit packed
|
||||
return (sign << 3) | idx;
|
||||
};
|
||||
|
||||
// Quantize 4 BF16 values -> 4 E2M1 nibbles
|
||||
float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
|
||||
float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
|
||||
uint8_t e0 = quant_e2m1(upper.x, 1.0f);
|
||||
uint8_t e1 = quant_e2m1(upper.y, 1.0f);
|
||||
uint8_t e2 = quant_e2m1(lower.x, 1.0f);
|
||||
uint8_t e3 = quant_e2m1(lower.y, 1.0f);
|
||||
// Pack 2 nibbles per byte: (e1<<4)|e0, (e3<<4)|e2
|
||||
uint8_t b0 = (e1 << 4) | e0;
|
||||
uint8_t b1 = (e3 << 4) | e2;
|
||||
// Quantize 4 BF16 values -> 4 E2M1 nibbles -> pack into 2 bytes
|
||||
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
|
||||
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
|
||||
uint8_t e0 = quant_e2m1(upper.x);
|
||||
uint8_t e1 = quant_e2m1(upper.y);
|
||||
uint8_t e2 = quant_e2m1(lower.x);
|
||||
uint8_t e3 = quant_e2m1(lower.y);
|
||||
// Pack 4 nibbles into the low 2 bytes of a uint32: byte 0 = (e1<<4)|e0, byte 1 = (e3<<4)|e2
|
||||
// Note: only fills 2 of the 4 STSM bytes. The other 2 bytes will be
|
||||
// written by the adjacent warp (same XOR swizzle column, different row range).
|
||||
uint32_t packed = (uint32_t)((e1 << 4) | e0)
|
||||
| ((uint32_t)((e3 << 4) | e2) << 8);
|
||||
|
||||
// Store packed FP4 bytes to SMEM (row-major, L1_OUT_BLOCK_N/2 bytes per row)
|
||||
uint32_t row = lane_idx; // lane maps to row within ATOM_M
|
||||
// STSM with XOR swizzle (same pattern as FP8, but halved byte stride)
|
||||
uint32_t row = lane_idx;
|
||||
uint32_t col = warp_idx_in_wg;
|
||||
// SMEM layout: simple row-major, L1_OUT_BLOCK_N/2 bytes per row
|
||||
// Each STSM atom wrote 4 FP8 values = 4 bytes. Now we write 2 bytes (4 FP4 values packed).
|
||||
// Column offset: each warp handles 2 bytes of N per row (previously 4 FP8 values = 4 bytes, now 4 FP4 values = 2 bytes)
|
||||
const auto smem_base = smem_cd[tma_stage_idx]
|
||||
const auto smem_ptr = smem_cd[tma_stage_idx]
|
||||
+ epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2)
|
||||
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2);
|
||||
// Bank-conflict-free addressing: interleave with 4-byte offset per warp
|
||||
uint32_t byte_col = col * 2; // 2 bytes per warp per row
|
||||
auto smem_ptr = smem_base + row * (L1_OUT_BLOCK_N / 2) + byte_col;
|
||||
// Write 2 packed FP4 bytes
|
||||
smem_ptr[0] = b0;
|
||||
smem_ptr[1] = b1;
|
||||
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2)
|
||||
+ row * (L1_OUT_BLOCK_N / 2)
|
||||
+ (col ^ (row / 2)) * kNumBankGroupBytes;
|
||||
ptx::SM100_U32x1_STSM_T<uint32_t>::copy(packed, smem_ptr);
|
||||
|
||||
// Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout)
|
||||
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
|
||||
|
||||
Reference in New Issue
Block a user