Files
nvfp4-megamoe-kernel/dsv4/kernels/cuda/flush_write.cu
biondizzle 0f539e4855 Flush compressor: schema fix, prepare_forward, flush_write kernels, state rotation
Schema fix (paper eq.11-12):
  CSA needs m entries for current a-stream AND m entries for previous
  b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After flush,
  current a-stream becomes next flush b-stream input.
  HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
  tail_zb initialized to -1e9 so softmax naturally masks b-stream on
  first flush (paper: Z^b padded with -inf, C^b with zeros).

prepare_forward.py:
  Runs between captured graphs. Computes new compressed entries from
  position delta, pre-allocates blocks before the graph runs.
  Deterministic: entries_after - entries_before, ceil to block boundary.
  No allocation inside the captured graph.

flush_write.cu — 4 kernels:
  flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter compressed
    entry + FP4 NVFP4 indexer key write (16-element groups, E4M3 scale).
    One block per request, 128 threads. Amax reduction -> inv_scale.
  flush_write_hca_kernel: same minus indexer (no FP4 write).
  csa_rotate_state_kernel: after CSA flush, rotate a->b stream,
    clear a-stream, reset tail_len.
  hca_reset_state_kernel: after HCA flush, clear a-stream, reset tail_len.

flush.py: Python orchestration.
  maybe_flush_csa/hca: always runs, kernels gate via valid_mask.
  Compressor produces entry, flush kernel quantize-scatters, state
  kernel rotates/resets. No host-side branching for cudagraph.

All tests pass on B200:
  Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
  State: tail_zb initialized to -1e9, reset_slot preserves it
  prepare_forward: correct block allocation for position transitions
  HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
  CSA flush write: RoPE exact, indexer FP4 keys written
  CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
  HCA state reset: ka/za zeroed, tail_len=0
2026-05-22 00:25:47 +00:00

451 lines
16 KiB
Plaintext

// flush_write.cu — Quantize and scatter compressed entries into paged KV pool.
//
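// Two flush-write kernel variants:
// flush_write_csa_kernel: writes compressed entry + FP4 (NVFP4) indexer key
// flush_write_hca_kernel: writes compressed entry only (no indexer)
// plus two per-request state kernels launched after a flush:
// csa_rotate_state_kernel (a-stream -> b-stream, then clear a-stream) and
// hca_reset_state_kernel (clear a-stream).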
//
// Both quantize the non-RoPE part of the entry (head_dim - rope_dim values)
// from BF16 to FP8 E4M3 with one amax-derived scale per compressed entry,
// and copy the trailing rope_dim RoPE values unchanged as BF16.
//
// One block per request. Each block handles writing ONE compressed entry
// per flush. At decode (B small, 1 entry/flush) this is 1-16 CTAs.
// At prefill (B up to 128), this is up to 128 CTAs — good occupancy.
//
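// Blocks have 128 threads (chosen for Blackwell SM100). The FP8 loop strides
// over the head_dim - rope_dim non-RoPE values, so head_dim=512 needs at most
// 4 elements per thread. The FP4 indexer quantize assigns one thread per
// 16-element group, so indexer_head_dim=128 uses 8 threads, 16 elements each.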
#include <cstdint>
#include <limits>

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
// ---- Warp-level reductions ----
// Warp-wide maximum of `val` over all 32 lanes (shuffle-down tree).
//
// Only lane 0 is guaranteed to hold the full warp maximum on return; the
// other lanes hold partial maxima.
//
// Precondition: called by all 32 lanes of a fully populated warp (the
// shuffle mask is 0xffffffff).
//
// This is a plain max of its inputs. Callers that want an absolute maximum
// pass |x| (the amax reductions in this file already pass non-negative
// partial amax values). Applying fabsf to only one operand, as before,
// gave a result that was neither max(x) nor max(|x|) for negative inputs.
__device__ __forceinline__ float warp_reduce_max(float val) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float other = __shfl_down_sync(0xffffffff, val, offset);
        val = fmaxf(val, other);
    }
    return val;
}
// Warp-wide sum of `val` via a shuffle-down tree over all 32 lanes.
// Only lane 0 is guaranteed to hold the complete sum on return.
// Requires a fully populated warp (mask 0xffffffff). Not called anywhere
// in this file at present.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    int offset = 16;
    while (offset > 0) {
        const float incoming = __shfl_down_sync(0xffffffff, val, offset);
        val = val + incoming;
        offset >>= 1;
    }
    return val;
}
// ---- Block-level amax ----
// Block-wide maximum of each thread's non-negative partial amax `val`.
//
// The result is broadcast through shared memory, so EVERY thread of the
// block receives the block maximum (not just thread 0).
//
// Preconditions:
//   * reached by all threads of the block (contains __syncthreads());
//   * blockDim.x is a multiple of 32 and at most 1024;
//   * n_warps == blockDim.x / 32;
//   * val >= 0 (lanes with no warp slot contribute 0.0f).
__device__ __forceinline__ float block_reduce_amax(float val, int n_warps) {
    // One slot per possible warp (1024 threads / 32 lanes). A fixed size of 4
    // would be written out of bounds by any block larger than 128 threads.
    __shared__ float smem[32];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;

    const float warp_amax = warp_reduce_max(val);
    if (lane == 0) {
        smem[warp] = warp_amax;
    }
    __syncthreads();

    if (warp == 0) {
        float v = (lane < n_warps) ? smem[lane] : 0.0f;
        v = warp_reduce_max(v);
        // Lane 0 is the only thread that read slot 0, and it did so above,
        // so overwriting slot 0 with the final value is race-free.
        if (lane == 0) {
            smem[0] = v;
        }
    }
    __syncthreads();
    const float result = smem[0];
    // Keep smem stable until every thread has read the result, in case this
    // helper is invoked again in the same kernel.
    __syncthreads();
    return result;
}
// ---- NVFP4 quantization for indexer keys ----
// NVFP4 = FP4 E2M1 elements in 16-element groups with one FP8 E4M3 scale
// per group. E2M1 is a signed 4-bit format: bit 3 is the sign, and codes
// 0..7 in bits 0-2 encode the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.

// Encode one value (already divided by its group scale) as an E2M1 nibble.
// Magnitudes above 6 saturate to 6. Ties round to the code with an even
// mantissa bit (round-half-to-even). A NaN input falls through every
// comparison and encodes as +6. Returns a value in [0, 15].
__host__ __device__ __forceinline__ uint8_t fp4_e2m1_encode(float x) {
    const float a = fabsf(x);
    uint8_t code;
    if (a <= 0.25f)      code = 0;  // 0.0
    else if (a < 0.75f)  code = 1;  // 0.5
    else if (a <= 1.25f) code = 2;  // 1.0
    else if (a < 1.75f)  code = 3;  // 1.5
    else if (a <= 2.5f)  code = 4;  // 2.0
    else if (a < 3.5f)   code = 5;  // 3.0
    else if (a <= 5.0f)  code = 6;  // 4.0
    else                 code = 7;  // 6.0 (saturated)
    // Keep zero canonical (+0) so an all-zero group packs to 0x00 bytes.
    return (x < 0.0f && code != 0) ? static_cast<uint8_t>(code | 0x8) : code;
}

// Quantize one 16-element group to NVFP4.
//
//   input     : 16 BF16 values
//   output    : 8 bytes; element 2*i is stored in the low nibble of
//               output[i] and element 2*i+1 in the high nibble
//   scale_out : 1 byte, the group scale encoded as FP8 E4M3
//
// The scale is amax / 6 (6 is the largest E2M1 magnitude), converted to
// E4M3 with the conversion's saturation to the finite range. Elements are
// divided by the *decoded* E4M3 scale, so e2m1_value * e4m3_scale
// reconstructs what was quantized. If the scale rounds to zero in E4M3
// (all-zero group, or amax below roughly 6 * 2^-10) every element encodes
// as 0.
__device__ __forceinline__ void quantize_fp4_group(
    const __nv_bfloat16* __restrict__ input,  // 16 elements
    uint8_t* __restrict__ output,             // 8 bytes (2 FP4 per byte)
    uint8_t* __restrict__ scale_out           // 1 FP8 E4M3 scale
) {
    float vals[16];
    float amax = 0.0f;
#pragma unroll
    for (int i = 0; i < 16; i++) {
        vals[i] = __bfloat162float(input[i]);
        amax = fmaxf(amax, fabsf(vals[i]));
    }

    const __nv_fp8_e4m3 fp8_scale(amax / 6.0f);
    *scale_out = fp8_scale.__x;
    // Quantize against the scale a reader will actually decode.
    const float scale = static_cast<float>(fp8_scale);
    const float rcp = (scale > 0.0f) ? (1.0f / scale) : 0.0f;

#pragma unroll
    for (int i = 0; i < 8; i++) {
        const uint8_t lo = fp4_e2m1_encode(vals[2 * i] * rcp);
        const uint8_t hi = fp4_e2m1_encode(vals[2 * i + 1] * rcp);
        output[i] = static_cast<uint8_t>((hi << 4) | lo);
    }
}
// ===========================================================================
// CSA flush write kernel
// ===========================================================================
// Quantize one compressed CSA entry per request and scatter it, with its
// NVFP4 indexer key, into the paged pool.
//
// Launch: grid = (B), one CTA per request; block = blockDim.x threads, a
// multiple of 32 (the host launcher uses 128). No dynamic shared memory.
// A CTA whose valid_mask[b] is false writes nothing.
//
// Destination: entry_idx = positions[b] / m selects the compressed entry,
// stored in logical block entry_idx / entries_per_block, which
// block_table[b, :] maps to a physical block.
//
// Written per entry:
//   * FP8 E4M3 non-RoPE part (head_dim - rope_dim values), with one scale
//     amax / 448 written to `inv_scale`. Despite its name, that tensor holds
//     the dequantization factor (fp8_value * scale ~= original value).
//   * BF16 RoPE part (trailing rope_dim values), copied unchanged.
//   * NVFP4 indexer key via quantize_fp4_group, one thread per 16-element
//     group.
//
// request_slots is not read by this kernel.
//
// Preconditions: tensors are contiguous in the shapes noted below. They
// are not checked on the device.
__global__ void flush_write_csa_kernel(
    // Inputs
    const __nv_bfloat16* __restrict__ entry,        // [B, head_dim] BF16
    const __nv_bfloat16* __restrict__ indexer_key,  // [B, indexer_head_dim] BF16
    const bool* __restrict__ valid_mask,            // [B]
    const int32_t* __restrict__ request_slots,      // [B] (not read)
    const int32_t* __restrict__ positions,          // [B]
    const int32_t* __restrict__ block_table,        // [B, max_logical_blocks]
    // Outputs — paged pool tensors, mutated in place
    uint8_t* __restrict__ entries_fp8,              // [num_blocks, epb, fp8_dim]
    __nv_bfloat16* __restrict__ entries_rope,       // [num_blocks, epb, rope_dim]
    float* __restrict__ inv_scale,                  // [num_blocks, epb]
    uint8_t* __restrict__ indexer_keys_fp4,         // [num_blocks, epb, ihd/2]
    uint8_t* __restrict__ indexer_scale,            // [num_blocks, epb, ihd/16]
    // Geometry
    int entries_per_block, int m, int rope_dim,
    int head_dim, int indexer_head_dim, int max_logical_blocks
) {
    const int b = blockIdx.x;
    if (!valid_mask[b]) return;  // Early exit for no-op requests.

    // ---- Resolve destination slot in the paged pool ----
    const int entry_idx = positions[b] / m;  // which compressed entry index
    const int logical_block = entry_idx / entries_per_block;
    const int slot_in_block = entry_idx % entries_per_block;
    // Block-uniform guards (every thread sees the same b), so these returns
    // cannot strand a later __syncthreads(). They turn a missing block-table
    // mapping into a skipped write instead of a wild write into the pool.
    // NOTE(review): negative block-table entries are assumed to mean
    // "unallocated"; confirm against the block-table owner.
    if (logical_block < 0 || logical_block >= max_logical_blocks) return;
    const int phys_block =
        block_table[static_cast<int64_t>(b) * max_logical_blocks + logical_block];
    if (phys_block < 0) return;

    // Flat pool index of the destination entry. Offsets derived from it are
    // 64-bit: (phys_block * epb + slot) * fp8_dim overflows int32 once the
    // FP8 pool exceeds INT_MAX bytes.
    const int64_t dst =
        static_cast<int64_t>(phys_block) * entries_per_block + slot_in_block;

    const int fp8_dim = head_dim - rope_dim;
    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;
    const __nv_bfloat16* src = entry + static_cast<int64_t>(b) * head_dim;

    // ---- Step 1: amax over the non-RoPE part ----
    float local_amax = 0.0f;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        local_amax = fmaxf(local_amax, fabsf(__bfloat162float(src[i])));
    }
    const float block_amax = block_reduce_amax(local_amax, n_warps);

    // ---- Step 2: per-entry FP8 scale ----
    // Only thread 0 is relied on to hold the full block amax. 448 is the
    // largest finite E4M3 magnitude; the floor keeps the division finite.
    __shared__ float s_scale;
    if (tid == 0) {
        const float scale = (block_amax > 1e-12f) ? (block_amax / 448.0f) : 1e-12f;
        s_scale = scale;
        inv_scale[dst] = scale;
    }
    __syncthreads();
    const float scale = s_scale;

    // ---- Step 3: quantize and write the FP8 part ----
    uint8_t* fp8_dst = entries_fp8 + dst * fp8_dim;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        const float q =
            fmaxf(-448.0f, fminf(448.0f, __bfloat162float(src[i]) / scale));
        fp8_dst[i] = __nv_fp8_e4m3(q).__x;
    }

    // ---- Step 4: copy the BF16 RoPE part ----
    __nv_bfloat16* rope_dst = entries_rope + dst * rope_dim;
    for (int i = tid; i < rope_dim; i += n_threads) {
        rope_dst[i] = src[fp8_dim + i];
    }

    // ---- Step 5: NVFP4-quantize and write the indexer key ----
    // One thread per 16-element group; each group yields 8 packed bytes and
    // one E4M3 scale byte. Elements past the last full group are not written.
    const int n_groups = indexer_head_dim / 16;
    const __nv_bfloat16* key_src =
        indexer_key + static_cast<int64_t>(b) * indexer_head_dim;
    uint8_t* key_dst = indexer_keys_fp4 + dst * (indexer_head_dim / 2);
    uint8_t* key_scale_dst = indexer_scale + dst * n_groups;
    for (int g = tid; g < n_groups; g += n_threads) {
        __nv_bfloat16 group_in[16];
        for (int j = 0; j < 16; j++) {
            group_in[j] = key_src[g * 16 + j];
        }
        quantize_fp4_group(group_in, key_dst + g * 8, key_scale_dst + g);
    }
}
// ===========================================================================
// HCA flush write kernel (no indexer)
// ===========================================================================
// Quantize one compressed HCA entry per request and scatter it into the
// paged pool. There is no indexer key for HCA.
//
// Launch: grid = (B), one CTA per request; block = blockDim.x threads, a
// multiple of 32 (the host launcher uses 128). No dynamic shared memory.
// A CTA whose valid_mask[b] is false writes nothing.
//
// Destination: entry_idx = positions[b] / m selects the compressed entry,
// stored in logical block entry_idx / entries_per_block, which
// block_table[b, :] maps to a physical block.
//
// Written per entry:
//   * FP8 E4M3 non-RoPE part (head_dim - rope_dim values), with one scale
//     amax / 448 written to `inv_scale`. Despite its name, that tensor holds
//     the dequantization factor (fp8_value * scale ~= original value).
//   * BF16 RoPE part (trailing rope_dim values), copied unchanged.
//
// request_slots is not read by this kernel.
__global__ void flush_write_hca_kernel(
    const __nv_bfloat16* __restrict__ entry,    // [B, head_dim] BF16
    const bool* __restrict__ valid_mask,        // [B]
    const int32_t* __restrict__ request_slots,  // [B] (not read)
    const int32_t* __restrict__ positions,      // [B]
    const int32_t* __restrict__ block_table,    // [B, max_logical_blocks]
    uint8_t* __restrict__ entries_fp8,          // [num_blocks, epb, fp8_dim]
    __nv_bfloat16* __restrict__ entries_rope,   // [num_blocks, epb, rope_dim]
    float* __restrict__ inv_scale,              // [num_blocks, epb]
    int entries_per_block, int m, int rope_dim,
    int head_dim, int max_logical_blocks
) {
    const int b = blockIdx.x;
    if (!valid_mask[b]) return;

    // ---- Resolve destination slot in the paged pool ----
    const int entry_idx = positions[b] / m;
    const int logical_block = entry_idx / entries_per_block;
    const int slot_in_block = entry_idx % entries_per_block;
    // Block-uniform guards (every thread sees the same b), so these returns
    // cannot strand a later __syncthreads(). They turn a missing block-table
    // mapping into a skipped write instead of a wild write into the pool.
    // NOTE(review): negative block-table entries are assumed to mean
    // "unallocated"; confirm against the block-table owner.
    if (logical_block < 0 || logical_block >= max_logical_blocks) return;
    const int phys_block =
        block_table[static_cast<int64_t>(b) * max_logical_blocks + logical_block];
    if (phys_block < 0) return;

    // Flat pool index of the destination entry; derived offsets are 64-bit
    // so large pools do not overflow int32.
    const int64_t dst =
        static_cast<int64_t>(phys_block) * entries_per_block + slot_in_block;

    const int fp8_dim = head_dim - rope_dim;
    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;
    const __nv_bfloat16* src = entry + static_cast<int64_t>(b) * head_dim;

    // ---- amax over the non-RoPE part ----
    float local_amax = 0.0f;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        local_amax = fmaxf(local_amax, fabsf(__bfloat162float(src[i])));
    }
    const float block_amax = block_reduce_amax(local_amax, n_warps);

    // ---- per-entry FP8 scale (thread 0 holds the full amax) ----
    __shared__ float s_scale;
    if (tid == 0) {
        const float scale = (block_amax > 1e-12f) ? (block_amax / 448.0f) : 1e-12f;
        s_scale = scale;
        inv_scale[dst] = scale;
    }
    __syncthreads();
    const float scale = s_scale;

    // ---- FP8 quantize + write (448 = largest finite E4M3 magnitude) ----
    uint8_t* fp8_dst = entries_fp8 + dst * fp8_dim;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        const float q =
            fmaxf(-448.0f, fminf(448.0f, __bfloat162float(src[i]) / scale));
        fp8_dst[i] = __nv_fp8_e4m3(q).__x;
    }

    // ---- BF16 RoPE part ----
    __nv_bfloat16* rope_dst = entries_rope + dst * rope_dim;
    for (int i = tid; i < rope_dim; i += n_threads) {
        rope_dst[i] = src[fp8_dim + i];
    }
}
// ===========================================================================
// State rotation kernels (in-place, single-kernel launches)
// ===========================================================================
// CSA: after a flush, the current a-stream becomes the next flush's
// b-stream (kb <- ka, zb <- za). The a-stream is then zeroed and the
// request's tail_len reset to 0.
//
// Launch: grid = (B), any block size. CTA b handles slot request_slots[b]
// and does nothing when valid_mask[b] is false. For each slot it touches
// m * head_dim contiguous elements of every tail tensor. max_requests is
// not read.
__global__ void csa_rotate_state_kernel(
    const bool* __restrict__ valid_mask,        // [B]
    const int32_t* __restrict__ request_slots,  // [B]
    // State cache tensors — mutated in place
    __nv_bfloat16* __restrict__ tail_ka,        // [max_req, m, head_dim]
    __nv_bfloat16* __restrict__ tail_za,
    __nv_bfloat16* __restrict__ tail_kb,
    __nv_bfloat16* __restrict__ tail_zb,
    int32_t* __restrict__ tail_len,             // [max_req]
    int m, int head_dim, int max_requests
) {
    const int req = blockIdx.x;
    if (!valid_mask[req]) {
        return;
    }
    const int slot = request_slots[req];
    const int per_slot = m * head_dim;
    const int base = slot * per_slot;
    const int stride = blockDim.x;
    const __nv_bfloat16 zero = __float2bfloat16(0.0f);

    // Each element is read from the a-stream by the same thread that later
    // zeroes it, so copy-then-clear can share one pass.
    for (int k = threadIdx.x; k < per_slot; k += stride) {
        const int idx = base + k;
        tail_kb[idx] = tail_ka[idx];
        tail_zb[idx] = tail_za[idx];
        tail_ka[idx] = zero;
        tail_za[idx] = zero;
    }
    if (threadIdx.x == 0) {
        tail_len[slot] = 0;
    }
}
// HCA: after a flush, zero the a-stream (tail_ka, tail_za) of the request's
// slot and reset its tail_len to 0. HCA keeps no b-stream, so nothing is
// rotated.
//
// Launch: grid = (B), any block size. CTA b handles slot request_slots[b]
// and does nothing when valid_mask[b] is false. m * head_dim contiguous
// elements per slot are cleared in each tail tensor. max_requests is not
// read.
__global__ void hca_reset_state_kernel(
    const bool* __restrict__ valid_mask,
    const int32_t* __restrict__ request_slots,
    __nv_bfloat16* __restrict__ tail_ka,
    __nv_bfloat16* __restrict__ tail_za,
    int32_t* __restrict__ tail_len,
    int m, int head_dim, int max_requests
) {
    const int req = blockIdx.x;
    if (!valid_mask[req]) {
        return;
    }
    const int slot = request_slots[req];
    const int per_slot = m * head_dim;
    const int stride = blockDim.x;
    const __nv_bfloat16 zero = __float2bfloat16(0.0f);
    __nv_bfloat16* ka = tail_ka + slot * per_slot;
    __nv_bfloat16* za = tail_za + slot * per_slot;

    if (threadIdx.x == 0) {
        tail_len[slot] = 0;
    }
    for (int k = threadIdx.x; k < per_slot; k += stride) {
        ka[k] = zero;
        za[k] = zero;
    }
}
// ===========================================================================
// PyTorch bindings
// ===========================================================================
// Host entry point for flush_write_csa_kernel (bound as "flush_write_csa").
//
// Tensor contract (checked here):
//   * every tensor is a contiguous CUDA tensor on entry's device;
//   * entry holds B * head_dim values and indexer_key holds
//     B * indexer_head_dim values, where B = entry.size(0);
//   * valid_mask, request_slots and positions hold at least B values;
//   * block_table is 2-D with at least B rows.
// Element dtypes are enforced by data_ptr<T>(). Pool tensors are written in
// place.
//
// The kernel runs on PyTorch's current CUDA stream for entry's device, not
// the legacy default stream. That keeps it ordered with surrounding PyTorch
// work and lets CUDA graph capture record it. With B == 0 nothing is
// launched, because a zero-sized grid is an invalid launch configuration.
void flush_write_csa_cuda(
    torch::Tensor entry,
    torch::Tensor indexer_key,
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor positions,
    torch::Tensor block_table,
    torch::Tensor entries_fp8,
    torch::Tensor entries_rope,
    torch::Tensor inv_scale,
    torch::Tensor indexer_keys_fp4,
    torch::Tensor indexer_scale,
    int64_t entries_per_block, int64_t m, int64_t rope_dim,
    int64_t head_dim, int64_t indexer_head_dim
) {
    const auto device = entry.device();
    auto check_tensor = [&device](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda(), "flush_write_csa: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == device,
                    "flush_write_csa: ", name, " must be on the same device as entry");
        TORCH_CHECK(t.is_contiguous(), "flush_write_csa: ", name, " must be contiguous");
    };
    check_tensor(entry, "entry");
    check_tensor(indexer_key, "indexer_key");
    check_tensor(valid_mask, "valid_mask");
    check_tensor(request_slots, "request_slots");
    check_tensor(positions, "positions");
    check_tensor(block_table, "block_table");
    check_tensor(entries_fp8, "entries_fp8");
    check_tensor(entries_rope, "entries_rope");
    check_tensor(inv_scale, "inv_scale");
    check_tensor(indexer_keys_fp4, "indexer_keys_fp4");
    check_tensor(indexer_scale, "indexer_scale");

    TORCH_CHECK(m > 0, "flush_write_csa: m must be positive");
    TORCH_CHECK(entries_per_block > 0, "flush_write_csa: entries_per_block must be positive");
    TORCH_CHECK(rope_dim >= 0 && rope_dim <= head_dim,
                "flush_write_csa: rope_dim must lie in [0, head_dim]");
    TORCH_CHECK(indexer_head_dim % 16 == 0,
                "flush_write_csa: indexer_head_dim must be a multiple of 16");
    TORCH_CHECK(block_table.dim() == 2, "flush_write_csa: block_table must be 2-D");

    const int64_t B = entry.size(0);
    TORCH_CHECK(entry.numel() == B * head_dim,
                "flush_write_csa: entry must hold B * head_dim elements");
    TORCH_CHECK(indexer_key.numel() == B * indexer_head_dim,
                "flush_write_csa: indexer_key must hold B * indexer_head_dim elements");
    TORCH_CHECK(valid_mask.numel() >= B && request_slots.numel() >= B &&
                    positions.numel() >= B,
                "flush_write_csa: valid_mask/request_slots/positions need B entries");
    TORCH_CHECK(block_table.size(0) >= B, "flush_write_csa: block_table needs B rows");
    if (B == 0) {
        return;
    }
    const int max_logical_blocks = static_cast<int>(block_table.size(1));

    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    constexpr int kThreads = 128;
    flush_write_csa_kernel<<<static_cast<unsigned int>(B), kThreads, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(entry.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(indexer_key.data_ptr<at::BFloat16>()),
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        positions.data_ptr<int32_t>(),
        block_table.data_ptr<int32_t>(),
        entries_fp8.data_ptr<uint8_t>(),
        reinterpret_cast<__nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
        inv_scale.data_ptr<float>(),
        indexer_keys_fp4.data_ptr<uint8_t>(),
        indexer_scale.data_ptr<uint8_t>(),
        static_cast<int>(entries_per_block), static_cast<int>(m),
        static_cast<int>(rope_dim), static_cast<int>(head_dim),
        static_cast<int>(indexer_head_dim), max_logical_blocks
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
// Host entry point for flush_write_hca_kernel (bound as "flush_write_hca").
//
// Tensor contract (checked here):
//   * every tensor is a contiguous CUDA tensor on entry's device;
//   * entry holds B * head_dim values, where B = entry.size(0);
//   * valid_mask, request_slots and positions hold at least B values;
//   * block_table is 2-D with at least B rows.
// Element dtypes are enforced by data_ptr<T>(). Pool tensors are written in
// place.
//
// The kernel runs on PyTorch's current CUDA stream for entry's device, not
// the legacy default stream. That keeps it ordered with surrounding PyTorch
// work and lets CUDA graph capture record it. With B == 0 nothing is
// launched, because a zero-sized grid is an invalid launch configuration.
void flush_write_hca_cuda(
    torch::Tensor entry,
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor positions,
    torch::Tensor block_table,
    torch::Tensor entries_fp8,
    torch::Tensor entries_rope,
    torch::Tensor inv_scale,
    int64_t entries_per_block, int64_t m, int64_t rope_dim,
    int64_t head_dim
) {
    const auto device = entry.device();
    auto check_tensor = [&device](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda(), "flush_write_hca: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == device,
                    "flush_write_hca: ", name, " must be on the same device as entry");
        TORCH_CHECK(t.is_contiguous(), "flush_write_hca: ", name, " must be contiguous");
    };
    check_tensor(entry, "entry");
    check_tensor(valid_mask, "valid_mask");
    check_tensor(request_slots, "request_slots");
    check_tensor(positions, "positions");
    check_tensor(block_table, "block_table");
    check_tensor(entries_fp8, "entries_fp8");
    check_tensor(entries_rope, "entries_rope");
    check_tensor(inv_scale, "inv_scale");

    TORCH_CHECK(m > 0, "flush_write_hca: m must be positive");
    TORCH_CHECK(entries_per_block > 0, "flush_write_hca: entries_per_block must be positive");
    TORCH_CHECK(rope_dim >= 0 && rope_dim <= head_dim,
                "flush_write_hca: rope_dim must lie in [0, head_dim]");
    TORCH_CHECK(block_table.dim() == 2, "flush_write_hca: block_table must be 2-D");

    const int64_t B = entry.size(0);
    TORCH_CHECK(entry.numel() == B * head_dim,
                "flush_write_hca: entry must hold B * head_dim elements");
    TORCH_CHECK(valid_mask.numel() >= B && request_slots.numel() >= B &&
                    positions.numel() >= B,
                "flush_write_hca: valid_mask/request_slots/positions need B entries");
    TORCH_CHECK(block_table.size(0) >= B, "flush_write_hca: block_table needs B rows");
    if (B == 0) {
        return;
    }
    const int max_logical_blocks = static_cast<int>(block_table.size(1));

    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    constexpr int kThreads = 128;
    flush_write_hca_kernel<<<static_cast<unsigned int>(B), kThreads, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(entry.data_ptr<at::BFloat16>()),
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        positions.data_ptr<int32_t>(),
        block_table.data_ptr<int32_t>(),
        entries_fp8.data_ptr<uint8_t>(),
        reinterpret_cast<__nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
        inv_scale.data_ptr<float>(),
        static_cast<int>(entries_per_block), static_cast<int>(m),
        static_cast<int>(rope_dim), static_cast<int>(head_dim),
        max_logical_blocks
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
// Host entry point for csa_rotate_state_kernel (bound as "csa_rotate_state").
//
// Requirements (checked here):
//   * every tensor is a contiguous CUDA tensor on valid_mask's device;
//   * request_slots holds at least B = valid_mask.size(0) values;
//   * each tail tensor has at most INT_MAX elements, because the kernel
//     computes slot * m * head_dim offsets in 32-bit int.
// Element dtypes are enforced by data_ptr<T>().
//
// The kernel is launched on PyTorch's current CUDA stream, so CUDA graph
// capture can record it. Nothing is launched when B == 0.
void csa_rotate_state_cuda(
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor tail_ka,
    torch::Tensor tail_za,
    torch::Tensor tail_kb,
    torch::Tensor tail_zb,
    torch::Tensor tail_len,
    int64_t m, int64_t head_dim
) {
    const auto device = valid_mask.device();
    auto check_tensor = [&device](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda(), "csa_rotate_state: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == device,
                    "csa_rotate_state: ", name, " must be on the same device as valid_mask");
        TORCH_CHECK(t.is_contiguous(), "csa_rotate_state: ", name, " must be contiguous");
        TORCH_CHECK(t.numel() <= std::numeric_limits<int>::max(),
                    "csa_rotate_state: ", name, " is too large for 32-bit indexing");
    };
    check_tensor(valid_mask, "valid_mask");
    check_tensor(request_slots, "request_slots");
    check_tensor(tail_ka, "tail_ka");
    check_tensor(tail_za, "tail_za");
    check_tensor(tail_kb, "tail_kb");
    check_tensor(tail_zb, "tail_zb");
    check_tensor(tail_len, "tail_len");
    TORCH_CHECK(m >= 0 && head_dim >= 0,
                "csa_rotate_state: m and head_dim must be non-negative");

    const int64_t B = valid_mask.size(0);
    TORCH_CHECK(request_slots.numel() >= B, "csa_rotate_state: request_slots needs B entries");
    if (B == 0) {
        return;
    }

    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    constexpr int kThreads = 128;
    csa_rotate_state_kernel<<<static_cast<unsigned int>(B), kThreads, 0, stream>>>(
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        reinterpret_cast<__nv_bfloat16*>(tail_ka.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_za.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_kb.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_zb.data_ptr<at::BFloat16>()),
        tail_len.data_ptr<int32_t>(),
        static_cast<int>(m), static_cast<int>(head_dim),
        static_cast<int>(tail_len.numel())  // max_requests (informational; unused)
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
// Host entry point for hca_reset_state_kernel (bound as "hca_reset_state").
//
// Requirements (checked here):
//   * every tensor is a contiguous CUDA tensor on valid_mask's device;
//   * request_slots holds at least B = valid_mask.size(0) values;
//   * each tail tensor has at most INT_MAX elements, because the kernel
//     computes slot * m * head_dim offsets in 32-bit int.
// Element dtypes are enforced by data_ptr<T>().
//
// The kernel is launched on PyTorch's current CUDA stream, so CUDA graph
// capture can record it. Nothing is launched when B == 0.
void hca_reset_state_cuda(
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor tail_ka,
    torch::Tensor tail_za,
    torch::Tensor tail_len,
    int64_t m, int64_t head_dim
) {
    const auto device = valid_mask.device();
    auto check_tensor = [&device](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda(), "hca_reset_state: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == device,
                    "hca_reset_state: ", name, " must be on the same device as valid_mask");
        TORCH_CHECK(t.is_contiguous(), "hca_reset_state: ", name, " must be contiguous");
        TORCH_CHECK(t.numel() <= std::numeric_limits<int>::max(),
                    "hca_reset_state: ", name, " is too large for 32-bit indexing");
    };
    check_tensor(valid_mask, "valid_mask");
    check_tensor(request_slots, "request_slots");
    check_tensor(tail_ka, "tail_ka");
    check_tensor(tail_za, "tail_za");
    check_tensor(tail_len, "tail_len");
    TORCH_CHECK(m >= 0 && head_dim >= 0,
                "hca_reset_state: m and head_dim must be non-negative");

    const int64_t B = valid_mask.size(0);
    TORCH_CHECK(request_slots.numel() >= B, "hca_reset_state: request_slots needs B entries");
    if (B == 0) {
        return;
    }

    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    constexpr int kThreads = 128;
    hca_reset_state_kernel<<<static_cast<unsigned int>(B), kThreads, 0, stream>>>(
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        reinterpret_cast<__nv_bfloat16*>(tail_ka.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_za.data_ptr<at::BFloat16>()),
        tail_len.data_ptr<int32_t>(),
        static_cast<int>(m), static_cast<int>(head_dim),
        static_cast<int>(tail_len.numel())  // max_requests (informational; unused)
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
// Python bindings. Every bound function returns void (None in Python) and
// mutates its output / state tensors in place.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // Paged-pool writers.
    ext.def("flush_write_csa", &flush_write_csa_cuda, "CSA flush write kernel");
    ext.def("flush_write_hca", &flush_write_hca_cuda, "HCA flush write kernel");
    // Per-request tail-state maintenance.
    ext.def("csa_rotate_state", &csa_rotate_state_cuda, "CSA state rotation kernel");
    ext.def("hca_reset_state", &hca_reset_state_cuda, "HCA state reset kernel");
}