Files
nvfp4-megamoe-kernel/dsv4/kernels/cuda/kv_quantize.cu

373 lines
13 KiB
Plaintext

/**
* Quantize FP32 tensor to NVFP4.
*
* Same proven pattern as quantize_nvfp4.cu (which reads BF16),
* but takes FP32 input directly — avoids BF16 intermediate.
*
* This is the correct path for compressor output → NVFP4:
* Compressor produces FP32 → this kernel → NVFP4 stored in KV cache
* No BF16 anywhere in the pipeline.
*
* Two-kernel approach (proven correct in fused_amax_quantize.cu):
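* Kernel 1: compute_amax_gsa_fp32_kernel — compute per-row gsa from FP32 input (GPU-only)
* Kernel 2: quantize_nvfp4_from_fp32_kernel — quantize FP32 → NVFP4 using GPU gsa buffer
*
* Kernel 2 grid: (N/16, M, 1) — each CTA processes one 16-element block in one row,
* where M is the row count passed to that launch.
* Kernel 2 block: 16 threads. Every thread reads the CTA's whole 16-element block and
* derives the block amax on its own; there is no cross-thread (warp shuffle) reduction.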
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <algorithm>
#include <cfloat>
#include <cstdint>
#include <limits>
#include <tuple>
/**
 * Translate a rounded half-step magnitude into a 3-bit E2M1 magnitude code.
 *
 * `hs` is a magnitude counted in steps of 0.5 (the caller passes
 * round(|x| * 2)). Code returned for each input range:
 *   hs <= 4    -> hs itself (codes 0..4; values below 0 pass through unchanged)
 *   hs == 5    -> 4
 *   hs 6..7    -> 5
 *   hs 8..10   -> 6
 *   hs >= 11   -> 7
 * The sign is not handled here; the caller adds 8 to the code for negatives.
 */
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
    return (hs <= 4)  ? hs
         : (hs == 5)  ? 4
         : (hs <= 7)  ? 5
         : (hs <= 10) ? 6
                      : 7;
}
// ===========================================================================
// Kernel 1: Compute per-row amax → gsa from FP32 input
// Same pattern as amax_gsa.cu but for FP32 (not BF16) input
//
// Launch: grid.x = one CTA per row (blockIdx.x is the row index);
//         block.x = a multiple of 32, at most 1024 (full-mask warp shuffles
//         are used, and the cross-warp scratch holds up to 32 warp results).
// Shared memory: 32 floats, static.
//
// input   : M rows of N contiguous floats
// divisor : out_gsa[m] = max_i |input[m, i]| / divisor
// out_gsa : M floats, one per row
// ===========================================================================
__global__ void compute_amax_gsa_fp32_kernel(
    const float* __restrict__ input,
    int M, int N,
    float divisor,
    float* __restrict__ out_gsa
) {
    const int m = blockIdx.x;
    if (m >= M) return;  // uniform for the whole CTA, so no barrier is skipped by a subset

    // 64-bit row offset: m * N can exceed INT_MAX for large inputs.
    const float* row = input + (size_t)m * N;

    // Each thread scans a strided slice of the row.
    float local_max = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        local_max = fmaxf(local_max, fabsf(row[i]));

    // Warp-level reduction. Shuffles only reach lanes of the same 32-lane warp,
    // so the first useful offset is 16.
    for (int offset = 16; offset > 0; offset >>= 1)
        local_max = fmaxf(local_max, __shfl_down_sync(0xffffffffu, local_max, offset));

    // Block-level reduction: lane 0 of every warp publishes its warp maximum.
    __shared__ float s_max[32];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;
    if (lane == 0)
        s_max[warp] = local_max;
    __syncthreads();

    // The first warp folds the per-warp maxima; unused slots contribute 0.
    if (warp == 0) {
        float v = (lane < num_warps) ? s_max[lane] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1)
            v = fmaxf(v, __shfl_down_sync(0xffffffffu, v, offset));
        if (lane == 0)
            out_gsa[m] = v / divisor;
    }
}
// ===========================================================================
// Kernel 2: Quantize FP32 → NVFP4 using gsa from GPU buffer
// Same proven pattern as quantize_nvfp4_from_buffer_kernel (fused_amax_quantize.cu)
// but reads FP32 instead of BF16
//
// Launch: blockIdx.x = 16-element block index inside the row, blockIdx.y = row.
//         Any blockDim.x >= 1 works: every thread derives the same block result
//         in registers, and the stores are split across the threads.
// Precondition: N is a multiple of 16 (the output row strides N/2 and N/16
//               assume it).
//
// input      : M rows of N contiguous floats
// gsa_buffer : M floats, per-row divisor applied before block scaling
// out_fp4    : M rows of N/2 bytes; byte j of a block packs element 2j in the
//              low nibble and element 2j+1 in the high nibble
// out_sf     : M rows of N/16 bytes; raw FP8 E4M3 bits of each block scale
// ===========================================================================
__global__ void quantize_nvfp4_from_fp32_kernel(
    const float* __restrict__ input,
    int M, int N,
    const float* __restrict__ gsa_buffer, // (M,) GPU buffer with per-row gsa
    uint8_t* __restrict__ out_fp4,
    uint8_t* __restrict__ out_sf
) {
    const int m = blockIdx.y;
    const int n_block = blockIdx.x;
    if (m >= M || n_block * 16 >= N) return;

    // 64-bit row offsets: m * N can exceed INT_MAX for large inputs.
    const size_t row = (size_t)m;
    const float* src = input + row * N;
    const float gsa = gsa_buffer[m];
    const int col0 = n_block * 16;

    // Step 1: Read 16 FP32 elements, divide by gsa, and track the block amax.
    // fmaxf ignores NaN operands, so an all-zero row (gsa == 0 → 0/0) ends up
    // with block_amax == 0 and is emitted as a zero block below.
    float vals[16];
    float block_amax = 0.0f;
    #pragma unroll
    for (int i = 0; i < 16; i++) {
        const int col = col0 + i;
        vals[i] = (col < N) ? (src[col] / gsa) : 0.0f;
        block_amax = fmaxf(block_amax, fabsf(vals[i]));
    }

    // Step 2: Compute FP8 E4M3 block scale (with FP8 round-trip).
    // Blocks whose scale would fall below 0.001953125 (2^-9) are flushed to zero.
    float bsf = block_amax / 6.0f;
    if (block_amax < 6.0f * 0.001953125f)
        bsf = 0.0f;
    __nv_fp8_e4m3 bsf8_obj(bsf);
    const float bs = (float)bsf8_obj;  // FP8 round-trip — matches dequant
    const uint8_t bsf8 = *reinterpret_cast<const uint8_t*>(&bsf8_obj);

    // Step 3: Quantize each value to an FP4 E2M1 code and pack all 16 codes into
    // one 64-bit word (element i occupies bits [4i, 4i+3]). A (near-)zero scale
    // leaves every code at 0.
    uint64_t packed = 0;
    if (!(bs < 1e-8f)) {
        #pragma unroll
        for (int i = 0; i < 16; i++) {
            const float s = vals[i] / bs;
            // |s| is clamped to 6.0, so the half-step count never exceeds 12.
            const int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
            int code = half_step_to_e2m1(hs);
            if (s < 0) code += 8;  // bit 3 carries the sign
            packed |= (uint64_t)code << (4 * i);
        }
    }

    // Step 4: Store the 8 packed bytes, one per thread (strided if the CTA has
    // fewer than 8 threads). Byte j = (code[2j+1] << 4) | code[2j].
    uint8_t* dst = out_fp4 + row * (size_t)(N / 2) + (size_t)n_block * 8;
    for (int j = threadIdx.x; j < 8; j += blockDim.x)
        dst[j] = (uint8_t)(packed >> (8 * j));

    // Step 5: Write the FP8 block scale once per CTA.
    if (threadIdx.x == 0)
        out_sf[row * (size_t)(N / 16) + n_block] = bsf8;
}
// ===========================================================================
// FP32 GPT-J interleaved RoPE (for compressed KV — no BF16 intermediate)
// Same math as rope_cuda.cu but operates on FP32 directly.
//
// Launch: grid.x = number of rows (one CTA per row), any blockDim.x >= 1.
// Only the last rope_dim columns of each row are rotated; adjacent
// (even, odd) column pairs form one rotation pair.
// Preconditions (not checked here): rope_dim is even and <= N, and every
// pos[m] is a valid row of cos_c / sin_c.
// ===========================================================================
__global__ void rope_fp32_kernel(
    float* __restrict__ x,              // M rows of N contiguous FP32 — modified in-place
    const float* __restrict__ cos_c,    // (max_pos, rope_dim/2) FP32
    const float* __restrict__ sin_c,    // (max_pos, rope_dim/2) FP32
    const int64_t* __restrict__ pos,    // (M,) positions
    int N, int rope_dim, bool inverse
) {
    // blockIdx.x is always < gridDim.x, so no row guard is needed; the launch
    // must use exactly one CTA per row.
    const int m = blockIdx.x;
    const int half = rope_dim / 2;
    const int64_t p = pos[m];

    // 64-bit row offset: m * N can exceed INT_MAX for large inputs.
    float* rot = x + (size_t)m * N + (N - rope_dim);  // first rotated column of row m
    const float* cos_row = cos_c + p * half;
    const float* sin_row = sin_c + p * half;

    for (int i = threadIdx.x; i < half; i += blockDim.x) {
        const float c = cos_row[i];
        const float s = sin_row[i];
        const float ev = rot[2 * i];
        const float od = rot[2 * i + 1];
        if (inverse) {
            // Rotation by -theta (undoes the forward rotation).
            rot[2 * i]     = ev * c + od * s;
            rot[2 * i + 1] = -ev * s + od * c;
        } else {
            rot[2 * i]     = ev * c - od * s;
            rot[2 * i + 1] = ev * s + od * c;
        }
    }
}
// ===========================================================================
// FP8 E4M3 quantize FP32 → FP8 (for indexer keys — higher precision)
//
// Launch: grid.x = one CTA per row; block.x = a multiple of 32, at most 1024
//         (full-mask warp shuffles; scratch holds up to 32 warp results).
// Shared memory: 33 floats, static.
//
// Per row m: scale = max(amax / 448, 1e-8); out_fp8 = fp8(clamp(x / scale, ±448)).
// ===========================================================================
__global__ void quantize_fp8_e4m3_from_fp32_kernel(
    const float* __restrict__ input,
    int M, int N,
    float* __restrict__ out_scale,   // (M,) per-row scale
    uint8_t* __restrict__ out_fp8    // (M, N) packed FP8 E4M3
) {
    const int m = blockIdx.x;
    if (m >= M) return;  // uniform for the whole CTA, so barriers stay safe

    // 64-bit row offset: m * N can exceed INT_MAX for large inputs.
    const size_t row_off = (size_t)m * N;
    const float* src = input + row_off;
    uint8_t* dst = out_fp8 + row_off;

    // Per-row amax → scale = amax / 448.0 (E4M3 max = 448)
    float local_max = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        local_max = fmaxf(local_max, fabsf(src[i]));

    // Warp-level reduction; shuffles stay inside one 32-lane warp, so start at 16.
    for (int offset = 16; offset > 0; offset >>= 1)
        local_max = fmaxf(local_max, __shfl_down_sync(0xffffffffu, local_max, offset));

    __shared__ float s_max[32];
    __shared__ float s_scale;  // row scale broadcast to the whole CTA
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;
    if (lane == 0) s_max[warp] = local_max;
    __syncthreads();

    if (warp == 0) {
        float v = (lane < num_warps) ? s_max[lane] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1)
            v = fmaxf(v, __shfl_down_sync(0xffffffffu, v, offset));
        if (lane == 0) {
            float scale = v / 448.0f;
            if (scale < 1e-8f) scale = 1e-8f;  // floor keeps 1/scale finite for all-zero rows
            out_scale[m] = scale;
            s_scale = scale;
        }
    }
    __syncthreads();

    // Quantize each element with the shared row scale.
    const float inv_scale = 1.0f / s_scale;
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        float v = src[i] * inv_scale;
        v = fmaxf(v, -448.0f);
        v = fminf(v, 448.0f);
        __nv_fp8_e4m3 obj(v);
        dst[i] = *reinterpret_cast<const uint8_t*>(&obj);
    }
}
// ===========================================================================
// FP8 E4M3 dequant → BF16 (for indexer key gather)
//
// Launch: grid.x = one CTA per row, any blockDim.x >= 1.
// output[m, i] = bf16(float(fp8_data[m, i]) * scale_data[m])
// ===========================================================================
__global__ void dequant_fp8_e4m3_kernel(
    const uint8_t* __restrict__ fp8_data,   // M rows of N raw FP8 E4M3 bytes
    const float* __restrict__ scale_data,   // (M,) per-row scale
    int M, int N,
    __nv_bfloat16* __restrict__ output      // M rows of N BF16 values
) {
    const int m = blockIdx.x;
    if (m >= M) return;

    // 64-bit row offset: m * N can exceed INT_MAX for large inputs.
    const size_t row_off = (size_t)m * N;
    const uint8_t* src = fp8_data + row_off;
    __nv_bfloat16* dst = output + row_off;
    const float scale = scale_data[m];

    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        const uint8_t byte = src[i];
        __nv_fp8_e4m3 val;
        memcpy(&val, &byte, 1);  // reinterpret the stored byte as FP8 E4M3
        dst[i] = __float2bfloat16((float)val * scale);
    }
}
// Gathering variant of the FP8 E4M3 → BF16 dequant.
//
// Launch: grid.x = one CTA per gathered row, any blockDim.x >= 1.
// output[k, i] = bf16(float(fp8_data[indices[k], i]) * scale_data[indices[k]])
// indices[k] is not range-checked; the caller must supply valid source rows.
__global__ void dequant_fp8_e4m3_selective_kernel(
    const uint8_t* __restrict__ fp8_data,   // source rows of N raw FP8 E4M3 bytes
    const float* __restrict__ scale_data,   // per-source-row scale
    const int32_t* __restrict__ indices,    // (K,) source row for each output row
    int K, int N,
    __nv_bfloat16* __restrict__ output      // K rows of N BF16 values
) {
    const int k = blockIdx.x;
    if (k >= K) return;

    const int src_row = indices[k];
    // 64-bit offsets: src_row * N and k * N can exceed INT_MAX for large caches.
    const uint8_t* src = fp8_data + (size_t)src_row * N;
    __nv_bfloat16* dst = output + (size_t)k * N;
    const float scale = scale_data[src_row];

    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        const uint8_t byte = src[i];
        __nv_fp8_e4m3 val;
        memcpy(&val, &byte, 1);  // reinterpret the stored byte as FP8 E4M3
        dst[i] = __float2bfloat16((float)val * scale);
    }
}
// ===========================================================================
// PyTorch bindings
// ===========================================================================
// Host wrapper for compute_amax_gsa_fp32_kernel.
//
// input   : 2-D float32 CUDA tensor (M, N); copied to contiguous layout if needed
// divisor : non-zero; each row's amax is divided by it
// returns : float32 tensor (M,) on input's device with out[m] = amax(row m) / divisor
// Raises (TORCH_CHECK) on a non-CUDA / non-float32 / non-2-D input, a zero
// divisor, or a kernel launch failure.
torch::Tensor compute_amax_gsa_fp32_cuda(torch::Tensor input, double divisor) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(input.dim() == 2, "input must be 2-D (M, N)");
    TORCH_CHECK(input.size(0) <= std::numeric_limits<int>::max() &&
                input.size(1) <= std::numeric_limits<int>::max(),
                "input dimensions must fit in a 32-bit int");
    TORCH_CHECK(divisor != 0.0, "divisor must be non-zero");

    // Launch on the tensor's device, not whatever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(input.device());
    input = input.contiguous();  // the kernel assumes a dense row-major layout

    const int M = static_cast<int>(input.size(0));
    const int N = static_cast<int>(input.size(1));
    // Every element is written by the kernel, so no zero-fill is needed.
    auto out_gsa = torch::empty({M}, input.options().dtype(torch::kFloat32));
    if (M == 0) return out_gsa;  // a zero-size grid is an invalid launch

    compute_amax_gsa_fp32_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(), M, N, (float)divisor, out_gsa.data_ptr<float>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out_gsa;
}
// Host wrapper for quantize_nvfp4_from_fp32_kernel.
//
// input      : 2-D float32 CUDA tensor (M, N), N a multiple of 16
// gsa_buffer : float32 CUDA tensor with M elements on the same device
// returns    : (packed FP4 codes viewed as float4_e2m1fn_x2, shape (M, N/2);
//               block scales viewed as float8_e4m3fn, shape (M, N/16))
// Raises (TORCH_CHECK) on shape/dtype/device mismatches or a launch failure.
std::tuple<torch::Tensor, torch::Tensor> quantize_nvfp4_from_fp32_cuda(
    torch::Tensor input, torch::Tensor gsa_buffer
) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(input.dim() == 2, "input must be 2-D (M, N)");
    TORCH_CHECK(input.size(0) <= std::numeric_limits<int>::max() &&
                input.size(1) <= std::numeric_limits<int>::max(),
                "input dimensions must fit in a 32-bit int");
    const int M = static_cast<int>(input.size(0));
    const int N = static_cast<int>(input.size(1));
    TORCH_CHECK(N % 16 == 0, "N must be a multiple of 16 for NVFP4 quantization");
    TORCH_CHECK(gsa_buffer.dim() >= 1 && gsa_buffer.size(0) == M, "gsa_buffer size must match M");
    TORCH_CHECK(gsa_buffer.numel() == M, "gsa_buffer must hold exactly one value per row");
    TORCH_CHECK(gsa_buffer.is_cuda() && gsa_buffer.device() == input.device(),
                "gsa_buffer must be on the same CUDA device as input");

    const c10::cuda::CUDAGuard device_guard(input.device());
    input = input.contiguous();          // kernel assumes dense row-major rows
    gsa_buffer = gsa_buffer.contiguous();

    auto opts = input.options();
    const int nb = N / 16;
    // Every output byte is written by the kernel, so no zero-fill is needed.
    auto out_fp4 = torch::empty({M, N / 2}, opts.dtype(torch::kUInt8));
    auto out_sf = torch::empty({M, nb}, opts.dtype(torch::kUInt8));

    if (M > 0 && nb > 0) {  // a zero-size grid is an invalid launch
        const float* in_ptr = input.data_ptr<float>();
        const float* gsa_ptr = gsa_buffer.data_ptr<float>();
        uint8_t* fp4_ptr = out_fp4.data_ptr<uint8_t>();
        uint8_t* sf_ptr = out_sf.data_ptr<uint8_t>();
        auto stream = c10::cuda::getCurrentCUDAStream();

        // Rows are mapped to grid.y, which CUDA caps at 65535, so larger inputs
        // are processed in row chunks with pre-offset pointers.
        constexpr int64_t kMaxGridY = 65535;
        for (int64_t row0 = 0; row0 < M; row0 += kMaxGridY) {
            const int rows = static_cast<int>(std::min<int64_t>(kMaxGridY, M - row0));
            dim3 grid(nb, rows);
            dim3 block(16);
            quantize_nvfp4_from_fp32_kernel<<<grid, block, 0, stream>>>(
                in_ptr + (size_t)row0 * N, rows, N, gsa_ptr + row0,
                fp4_ptr + (size_t)row0 * (N / 2), sf_ptr + (size_t)row0 * nb
            );
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
    }
    return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn)};
}
// Host wrapper for quantize_fp8_e4m3_from_fp32_kernel.
//
// input   : 2-D float32 CUDA tensor (M, N); copied to contiguous layout if needed
// returns : (FP8 bytes viewed as float8_e4m3fn, shape (M, N);
//            float32 per-row scales, shape (M,))
// Raises (TORCH_CHECK) on a non-CUDA / non-float32 / non-2-D input or a
// kernel launch failure.
std::tuple<torch::Tensor, torch::Tensor> quantize_fp8_e4m3_from_fp32_cuda(
    torch::Tensor input
) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(input.dim() == 2, "input must be 2-D (M, N)");
    TORCH_CHECK(input.size(0) <= std::numeric_limits<int>::max() &&
                input.size(1) <= std::numeric_limits<int>::max(),
                "input dimensions must fit in a 32-bit int");

    const c10::cuda::CUDAGuard device_guard(input.device());
    input = input.contiguous();  // kernel assumes dense row-major rows

    const int M = static_cast<int>(input.size(0));
    const int N = static_cast<int>(input.size(1));
    auto opts = input.options();
    // Both outputs are fully written by the kernel, so no zero-fill is needed.
    auto out_scale = torch::empty({M}, opts.dtype(torch::kFloat32));
    auto out_fp8 = torch::empty({M, N}, opts.dtype(torch::kUInt8));

    if (M > 0) {  // a zero-size grid is an invalid launch
        quantize_fp8_e4m3_from_fp32_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(), M, N,
            out_scale.data_ptr<float>(), out_fp8.data_ptr<uint8_t>()
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
    return {out_fp8.view(torch::kFloat8_e4m3fn), out_scale};
}
// Host wrapper for dequant_fp8_e4m3_kernel.
//
// fp8_data   : 2-D CUDA tensor (M, N) of FP8 E4M3 bytes; dtype uint8 or
//              float8_e4m3fn (the dtype returned by quantize_fp8_e4m3_from_fp32)
// scale_data : float32 CUDA tensor with at least M elements (row m uses element m)
// returns    : bfloat16 tensor (M, N)
// Raises (TORCH_CHECK) on shape/dtype/device mismatches or a launch failure.
torch::Tensor dequant_fp8_e4m3_cuda(
    torch::Tensor fp8_data, torch::Tensor scale_data
) {
    TORCH_CHECK(fp8_data.is_cuda(), "fp8_data must be a CUDA tensor");
    TORCH_CHECK(fp8_data.dim() == 2, "fp8_data must be 2-D (M, N)");
    const auto fp8_dtype = fp8_data.scalar_type();
    TORCH_CHECK(fp8_dtype == torch::kUInt8 || fp8_dtype == torch::kFloat8_e4m3fn,
                "fp8_data must be uint8 or float8_e4m3fn");
    TORCH_CHECK(fp8_data.size(0) <= std::numeric_limits<int>::max() &&
                fp8_data.size(1) <= std::numeric_limits<int>::max(),
                "fp8_data dimensions must fit in a 32-bit int");
    const int M = static_cast<int>(fp8_data.size(0));
    const int N = static_cast<int>(fp8_data.size(1));
    TORCH_CHECK(scale_data.is_cuda() && scale_data.device() == fp8_data.device(),
                "scale_data must be on the same CUDA device as fp8_data");
    TORCH_CHECK(scale_data.numel() >= M, "scale_data must hold one scale per row");

    const c10::cuda::CUDAGuard device_guard(fp8_data.device());
    fp8_data = fp8_data.contiguous();      // kernel assumes dense row-major rows
    scale_data = scale_data.contiguous();

    // Every element is written by the kernel, so no zero-fill is needed.
    auto output = torch::empty({M, N}, fp8_data.options().dtype(torch::kBFloat16));
    if (M == 0) return output;  // a zero-size grid is an invalid launch

    // Untyped pointer: both accepted dtypes are one byte per element.
    const uint8_t* fp8_ptr = static_cast<const uint8_t*>(fp8_data.data_ptr());
    dequant_fp8_e4m3_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
        fp8_ptr, scale_data.data_ptr<float>(), M, N,
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>())
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Host wrapper for dequant_fp8_e4m3_selective_kernel (row gather + dequant).
//
// fp8_data   : 2-D CUDA tensor (rows, N) of FP8 E4M3 bytes; dtype uint8 or
//              float8_e4m3fn
// scale_data : float32 CUDA tensor, indexed by source row
// indices    : int32 CUDA tensor with K source-row indices; values are NOT
//              range-checked (that would need a device sync), so the caller
//              must keep them inside fp8_data / scale_data
// returns    : bfloat16 tensor (K, N), row k dequantized from row indices[k]
// Raises (TORCH_CHECK) on dtype/device mismatches or a launch failure.
torch::Tensor dequant_fp8_e4m3_selective_cuda(
    torch::Tensor fp8_data, torch::Tensor scale_data, torch::Tensor indices
) {
    TORCH_CHECK(fp8_data.is_cuda(), "fp8_data must be a CUDA tensor");
    TORCH_CHECK(fp8_data.dim() == 2, "fp8_data must be 2-D (rows, N)");
    const auto fp8_dtype = fp8_data.scalar_type();
    TORCH_CHECK(fp8_dtype == torch::kUInt8 || fp8_dtype == torch::kFloat8_e4m3fn,
                "fp8_data must be uint8 or float8_e4m3fn");
    TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32");
    TORCH_CHECK(indices.dim() >= 1, "indices must have at least one dimension");
    TORCH_CHECK(scale_data.is_cuda() && scale_data.device() == fp8_data.device(),
                "scale_data must be on the same CUDA device as fp8_data");
    TORCH_CHECK(indices.is_cuda() && indices.device() == fp8_data.device(),
                "indices must be on the same CUDA device as fp8_data");
    TORCH_CHECK(indices.size(0) <= std::numeric_limits<int>::max() &&
                fp8_data.size(1) <= std::numeric_limits<int>::max(),
                "K and N must fit in a 32-bit int");
    const int K = static_cast<int>(indices.size(0));
    const int N = static_cast<int>(fp8_data.size(1));

    const c10::cuda::CUDAGuard device_guard(fp8_data.device());
    fp8_data = fp8_data.contiguous();      // kernel assumes dense row-major rows
    scale_data = scale_data.contiguous();
    indices = indices.contiguous();

    // Every element is written by the kernel, so no zero-fill is needed.
    auto output = torch::empty({K, N}, fp8_data.options().dtype(torch::kBFloat16));
    if (K == 0) return output;  // nothing selected; a zero-size grid is an invalid launch

    // Untyped pointer: both accepted dtypes are one byte per element.
    const uint8_t* fp8_ptr = static_cast<const uint8_t*>(fp8_data.data_ptr());
    dequant_fp8_e4m3_selective_kernel<<<K, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
        fp8_ptr, scale_data.data_ptr<float>(),
        indices.data_ptr<int32_t>(), K, N,
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>())
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Host wrapper for rope_fp32_kernel: rotates the last rope_dim columns of every
// row of x in place.
//
// x         : 2-D contiguous float32 CUDA tensor (M, N) — modified in-place
// positions : int64 CUDA tensor with M entries; positions[m] selects the cache row
//             (values are not range-checked against the cache length)
// cos_cache / sin_cache : float32 CUDA tensors whose rows hold rope_dim/2 values
// rope_dim  : even, 0 <= rope_dim <= N
// inverse   : apply the inverse rotation instead of the forward one
// Raises (TORCH_CHECK) on shape/dtype/device mismatches or a launch failure.
void rope_fp32_cuda(
    torch::Tensor x,          // (M, N) FP32 — modified in-place
    torch::Tensor positions,  // (M,) int64
    torch::Tensor cos_cache,  // (max_pos, rope_dim/2) FP32
    torch::Tensor sin_cache,  // (max_pos, rope_dim/2) FP32
    int64_t rope_dim,
    bool inverse
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
    TORCH_CHECK(x.dim() == 2, "x must be 2-D (M, N)");
    // x is updated in place, so a contiguous copy would discard the result.
    TORCH_CHECK(x.is_contiguous(), "x must be contiguous (it is rotated in place)");
    TORCH_CHECK(x.size(0) <= std::numeric_limits<int>::max() &&
                x.size(1) <= std::numeric_limits<int>::max(),
                "x dimensions must fit in a 32-bit int");
    const int M = static_cast<int>(x.size(0));
    const int N = static_cast<int>(x.size(1));

    TORCH_CHECK(rope_dim >= 0 && rope_dim <= N, "rope_dim must be in [0, N]");
    TORCH_CHECK(rope_dim % 2 == 0, "rope_dim must be even");
    TORCH_CHECK(positions.numel() == M, "positions must hold one entry per row of x");
    TORCH_CHECK(positions.is_cuda() && positions.device() == x.device(),
                "positions must be on the same CUDA device as x");
    TORCH_CHECK(cos_cache.is_cuda() && cos_cache.device() == x.device() &&
                sin_cache.is_cuda() && sin_cache.device() == x.device(),
                "cos_cache and sin_cache must be on the same CUDA device as x");

    const int64_t half = rope_dim / 2;
    if (M == 0 || half == 0) return;  // nothing to rotate; avoids a zero-size grid

    // The kernel indexes the caches with a row stride of rope_dim/2.
    TORCH_CHECK(cos_cache.dim() >= 1 && cos_cache.numel() == cos_cache.size(0) * half,
                "cos_cache rows must hold rope_dim/2 values");
    TORCH_CHECK(sin_cache.dim() >= 1 && sin_cache.numel() == sin_cache.size(0) * half,
                "sin_cache rows must hold rope_dim/2 values");

    const c10::cuda::CUDAGuard device_guard(x.device());
    // Read-only inputs can safely be densified.
    positions = positions.contiguous();
    cos_cache = cos_cache.contiguous();
    sin_cache = sin_cache.contiguous();

    rope_fp32_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        cos_cache.data_ptr<float>(),
        sin_cache.data_ptr<float>(),
        positions.data_ptr<int64_t>(),
        N, (int)rope_dim, inverse
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Python module definition: exposes each host wrapper defined in this file
// under a short name, with a one-line docstring shown by Python's help().
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// NVFP4 path, step 1: per-row gsa = amax / divisor.
m.def("compute_amax_gsa_fp32", &compute_amax_gsa_fp32_cuda,
"Compute per-row gsa from FP32 input (GPU-only, no CPU sync)");
// NVFP4 path, step 2: consumes the gsa buffer produced by step 1.
m.def("quantize_nvfp4_from_fp32", &quantize_nvfp4_from_fp32_cuda,
"Quantize FP32 → NVFP4 using gsa from GPU buffer");
// FP8 path: per-row scaled quantization; returns (fp8, scale).
m.def("quantize_fp8_e4m3_from_fp32", &quantize_fp8_e4m3_from_fp32_cuda,
"Quantize FP32 → FP8 E4M3 (for indexer keys)");
// FP8 path: dequantize every row.
m.def("dequant_fp8_e4m3", &dequant_fp8_e4m3_cuda,
"Dequant FP8 E4M3 → BF16");
// FP8 path: dequantize only the rows named by an int32 index tensor.
m.def("dequant_fp8_e4m3_selective", &dequant_fp8_e4m3_selective_cuda,
"Selective dequant FP8 E4M3 → BF16 (for CSA indexer gather)");
// In-place rotary embedding on an FP32 tensor; returns nothing.
m.def("rope_fp32", &rope_fp32_cuda,
"FP32 GPT-J interleaved RoPE (for compressed KV)");
}