fix quantize_nvfp4 kernel: use proven single-thread-per-CTA pattern from deinterleave_quantize.cu
The warp shuffle approach failed because __shfl_down_sync with 16 threads has undefined behavior for the odd nibble. Use the same pattern as the working deinterleave_quantize.cu: 1 CTA per 16-element block, 16 threads per CTA, each thread reads all 16 elements sequentially and computes amax + quantize + pack.
This commit is contained in:
@@ -14,11 +14,11 @@
|
||||
// Global scale is passed in as a pre-computed scalar.
|
||||
//
|
||||
// Grid: (N / 16, M, 1) — each CTA processes one 16-element block in one row.
|
||||
// Block: 16 threads (1 thread per element in the 16-element microblock).
|
||||
// Block: 16 threads (1 thread per element, warp amax reduction).
|
||||
//
|
||||
// Same proven pattern as deinterleave_quantize.cu.
|
||||
|
||||
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
|
||||
// Matches Python step_to_idx LUT:
|
||||
// 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→5, 8→6, 9→6, 10→6, 11→7, 12→7
|
||||
if (hs <= 4) return hs;
|
||||
if (hs <= 5) return 4;
|
||||
if (hs <= 7) return 5;
|
||||
@@ -37,58 +37,48 @@ __global__ void quantize_nvfp4_kernel(
|
||||
int n_block = blockIdx.x;
|
||||
if (m >= M || n_block * 16 >= N) return;
|
||||
|
||||
int lane = threadIdx.x; // 0..15
|
||||
float vals[16];
|
||||
float block_amax = 0.0f;
|
||||
|
||||
// Step 1: Read 1 BF16 element per thread, normalize by global_scale
|
||||
int col = n_block * 16 + lane;
|
||||
float val = 0.0f;
|
||||
if (col < N) {
|
||||
val = __bfloat162float(input[m * N + col]);
|
||||
val = val / global_scale;
|
||||
// Step 1: Read 16 BF16 elements and compute amax
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int col = n_block * 16 + i;
|
||||
if (col < N) {
|
||||
vals[i] = __bfloat162float(input[m * N + col]) / global_scale;
|
||||
} else {
|
||||
vals[i] = 0;
|
||||
}
|
||||
block_amax = fmaxf(block_amax, fabsf(vals[i]));
|
||||
}
|
||||
|
||||
// Step 2: Warp-level amax reduction (16 threads = half-warp)
|
||||
// Use warp shuffle for the reduction
|
||||
float abs_val = fabsf(val);
|
||||
for (int offset = 8; offset > 0; offset >>= 1) {
|
||||
abs_val = fmaxf(abs_val, __shfl_down_sync(0xFFFF, abs_val, offset));
|
||||
}
|
||||
float block_amax = abs_val; // Same value in all 16 lanes after reduction
|
||||
|
||||
// Step 3: Compute FP8 E4M3 block scale = amax / 6.0
|
||||
// Step 2: Compute FP8 E4M3 block scale
|
||||
float bsf = block_amax / 6.0f;
|
||||
if (block_amax < 6.0f * 0.001953125f) { // FP8 E4M3 underflow threshold
|
||||
if (block_amax < 6.0f * 0.001953125f) {
|
||||
bsf = 0;
|
||||
val = 0;
|
||||
for (int i = 0; i < 16; i++) vals[i] = 0;
|
||||
}
|
||||
__nv_fp8_e4m3 bsf8_obj(bsf);
|
||||
float bs = (float)bsf8_obj;
|
||||
uint8_t bsf8 = *(uint8_t*)&bsf8_obj;
|
||||
|
||||
// Step 4: Quantize each value to FP4 E2M1
|
||||
uint8_t nibble = 0;
|
||||
if (bs >= 1e-8f) {
|
||||
float s = val / bs;
|
||||
// Step 3: Quantize each value to FP4 E2M1
|
||||
uint8_t nibbles[16];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (bs < 1e-8f) { nibbles[i] = 0; continue; }
|
||||
float s = vals[i] / bs;
|
||||
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
|
||||
if (hs > 12) hs = 12;
|
||||
int idx = half_step_to_e2m1(hs);
|
||||
if (s < 0) idx += 8;
|
||||
nibble = idx;
|
||||
nibbles[i] = idx;
|
||||
}
|
||||
|
||||
// Step 5: Pack pairs of FP4 nibbles into bytes
|
||||
// Even lanes write the packed byte: (odd_nibble << 4) | even_nibble
|
||||
if (lane % 2 == 0) {
|
||||
uint8_t odd_nibble = __shfl_down_sync(0xFFFF, nibble, 1);
|
||||
uint8_t packed = (odd_nibble << 4) | nibble;
|
||||
int byte_idx = m * (N / 2) + n_block * 8 + (lane / 2);
|
||||
out_fp4[byte_idx] = packed;
|
||||
}
|
||||
// Step 4: Pack pairs: (nibbles[1] << 4) | nibbles[0], etc.
|
||||
for (int i = 0; i < 8; i++)
|
||||
out_fp4[m * (N / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i];
|
||||
|
||||
// Step 6: Write FP8 block scale (1 thread per block)
|
||||
if (lane == 0) {
|
||||
out_sf[m * (N / 16) + n_block] = bsf8;
|
||||
}
|
||||
// Step 5: Write FP8 block scale
|
||||
out_sf[m * (N / 16) + n_block] = bsf8;
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> quantize_nvfp4_cuda(
|
||||
|
||||
Reference in New Issue
Block a user