Files
nvfp4-megamoe-kernel/dsv4/kernels/compressor/production_compress.py
biondizzle 7ef6402936 KV-1/KV-2/KV-3: NVFP4 compressed KV + FP8 indexer keys
Architecture:
- Compressed KV: stored as NVFP4 (E2M1 + E4M3 + FP32 gsa)
  - Write path: compress→FP32 → FP32 RoPE → quantize FP32→NVFP4
  - Read path: dequant_nvfp4/dequant_nvfp4_selective → BF16 for FMHA
  - No BF16 intermediate in the write path
- Indexer keys: stored as FP8_E4M3 (1 byte + per-row scale)
  - Write path: compress→FP32 → quantize FP32→FP8_E4M3
  - Read path: dequant_fp8_e4m3 → BF16 for scoring
- SWA: remains BF16 (8MB total, fits in L2)

New kernels in kv_quantize.cu:
- compute_amax_gsa_fp32: per-row gsa from FP32 input
- quantize_nvfp4_from_fp32: FP32→NVFP4 with GPU gsa buffer
- quantize_fp8_e4m3_from_fp32: FP32→FP8_E4M3 for indexer keys
- dequant_fp8_e4m3 / dequant_fp8_e4m3_selective: FP8→BF16
- rope_fp32: FP32 GPT-J interleaved RoPE (no BF16)

Proven two-kernel pattern (same as quantize_nvfp4_gpu_fused):
  Kernel 1: amax_gsa (GPU-only)
  Kernel 2: quantize from buffer (GPU gsa)
No shared memory bugs. No cross-CTA race conditions.

KVCache updated:
- comp_kv_fp4/sf/gsa: NVFP4 storage (3.5× smaller than BF16)
- comp_idx_fp8/scale: FP8_E4M3 storage (1.9× smaller than BF16)
- comp_kv property: dequant NVFP4→BF16 on demand
- comp_kv_selective: dequant only top-k entries (bandwidth savings)
- comp_idx_kv property: dequant FP8→BF16 on demand

Removed: compressor_reduce_quant.cu (buggy single-kernel approach)
2026-06-02 10:00:50 +00:00

225 lines
8.2 KiB
Python

"""Production compressor: NVFP4 GEMM projections + CUDA softmax/reduce kernel.
Pipeline:
1. NVFP4 GEMM: hidden_states @ kv_proj → kv (T, kv_dim)
2. NVFP4 GEMM: hidden_states @ gate_proj → gate (T, kv_dim)
3. CUDA kernel: token-level softmax(gate) * kv → compressed entries
4. CUDA kernel: kv_norm (unweighted RMSNorm + weight)
KV-1/KV-2: NVFP4 output variants compress + quantize in a single kernel.
No intermediate BF16. Stored as FP4 data + E4M3 block scales + FP32 global scale.
No PyTorch softmax. No reference fallback. All on the GPU.
"""
from __future__ import annotations
import os
import torch
from typing import Optional
# Cached handle to the JIT-compiled compressor_reduce extension (None until
# the first successful _get_kernel() call).
_kernel_module = None


def _get_kernel():
    """Return the ``compressor_reduce`` CUDA extension, compiling it on first use.

    The module returned by ``torch.utils.cpp_extension.load`` is stored in the
    module-level ``_kernel_module``. Later calls return that cached object
    without rebuilding. If the build raises, nothing is cached and the next
    call tries again.
    """
    global _kernel_module
    if _kernel_module is None:
        # Imported lazily so importing this module does not pull in the
        # extension-build machinery.
        from torch.utils.cpp_extension import load

        cuda_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
        source_path = os.path.join(cuda_dir, "compressor_reduce.cu")
        _kernel_module = load(
            name="compressor_reduce",
            sources=[source_path],
            extra_cuda_cflags=[
                "-O3",
                "--generate-code=arch=compute_100a,code=[sm_100a]",
            ],
            verbose=False,
        )
    return _kernel_module
def csa_compress_production(
    kv_proj_out: torch.Tensor,  # (T, 2*hd) FP32 — output of NVFP4 GEMM
    gate_proj_out: torch.Tensor,  # (T, 2*hd) FP32 — output of NVFP4 GEMM
    position_bias: Optional[torch.Tensor],  # (m, 2*hd) BF16 or None
    kv_norm_weight: Optional[torch.Tensor],  # (hd) BF16 or None
    m: int = 4,
) -> torch.Tensor:
    """CSA compress (softmax + weighted sum + kv_norm), returned as BF16.

    Thin wrapper: runs ``csa_compress_production_fp32`` with the same
    arguments and converts its FP32 result to ``torch.bfloat16``.
    """
    compressed_fp32 = csa_compress_production_fp32(
        kv_proj_out,
        gate_proj_out,
        position_bias,
        kv_norm_weight,
        m,
    )
    return compressed_fp32.to(torch.bfloat16)
def csa_compress_production_fp32(
    kv_proj_out: torch.Tensor,
    gate_proj_out: torch.Tensor,
    position_bias: Optional[torch.Tensor],
    kv_norm_weight: Optional[torch.Tensor],
    m: int = 4,
) -> torch.Tensor:
    """CSA compress: softmax + weighted sum + kv_norm. Returns FP32.

    Splits the ``T`` input rows into ``T // m`` blocks of ``m`` tokens and
    runs the ``csa_compress_reduce`` extension, which fills one output row
    per block. A partial trailing block (``T % m`` tokens) gets no output row.

    Args:
        kv_proj_out: ``(T, 2*hd)`` KV projection output (FP32 per the module
            pipeline description). Passed to the kernel as a contiguous tensor.
        gate_proj_out: gate projection output; must have the same shape as
            ``kv_proj_out``.
        position_bias: optional position bias, cast to FP32. When ``None``,
            an empty FP32 tensor is passed to the kernel instead.
        kv_norm_weight: optional kv_norm weight, cast to FP32. When ``None``,
            an empty FP32 tensor is passed to the kernel instead.
        m: tokens per compressed block; must be positive.

    Returns:
        ``(T // m, hd)`` FP32 tensor on ``kv_proj_out.device``. When
        ``T < m`` this is ``(0, hd)`` and the extension is never loaded.

    Raises:
        ValueError: if ``m`` is not positive, ``kv_proj_out`` is not 2-D or
            has an odd number of columns, or ``gate_proj_out`` has a
            different shape. These are checked up front so a bad call fails
            in Python instead of handing mismatched buffers to the kernel.
    """
    if m <= 0:
        raise ValueError(f"m must be a positive block size, got {m}")
    if kv_proj_out.dim() != 2:
        raise ValueError(
            f"kv_proj_out must be 2-D (T, 2*hd), got shape {tuple(kv_proj_out.shape)}"
        )
    if kv_proj_out.shape[1] % 2 != 0:
        # hd is derived as columns // 2, so an odd width cannot be split into
        # the two hd-wide halves the CSA layout expects.
        raise ValueError(
            "kv_proj_out must have an even number of columns (T, 2*hd), "
            f"got shape {tuple(kv_proj_out.shape)}"
        )
    if gate_proj_out.shape != kv_proj_out.shape:
        raise ValueError(
            "gate_proj_out must have the same shape as kv_proj_out, got "
            f"{tuple(gate_proj_out.shape)} vs {tuple(kv_proj_out.shape)}"
        )

    T = kv_proj_out.shape[0]
    hd = kv_proj_out.shape[1] // 2
    n_blocks = T // m
    device = kv_proj_out.device
    if n_blocks == 0:
        # Not even one full block: nothing to compress, so skip the JIT load.
        return torch.zeros(0, hd, dtype=torch.float32, device=device)

    mod = _get_kernel()
    # Optional inputs are cast to FP32. An empty FP32 tensor is passed for
    # any that are None.
    pos_bias_f32 = (
        position_bias.float()
        if position_bias is not None
        else torch.empty(0, dtype=torch.float32, device=device)
    )
    norm_f32 = (
        kv_norm_weight.float()
        if kv_norm_weight is not None
        else torch.empty(0, dtype=torch.float32, device=device)
    )
    compressed = torch.zeros(n_blocks, hd, dtype=torch.float32, device=device)
    mod.csa_compress_reduce(
        kv_proj_out.contiguous(),
        gate_proj_out.contiguous(),
        pos_bias_f32.contiguous(),
        norm_f32.contiguous(),
        compressed,
        m, n_blocks,
    )
    return compressed
def hca_compress_production(
    kv_proj_out: torch.Tensor,  # (T, hd) FP32
    gate_proj_out: torch.Tensor,  # (T, hd) FP32
    position_bias: Optional[torch.Tensor],  # (m, hd) BF16 or None
    kv_norm_weight: Optional[torch.Tensor],  # (hd) BF16 or None
    m: int = 128,
) -> torch.Tensor:
    """HCA compress (softmax + weighted sum + kv_norm), returned as BF16.

    Thin wrapper: runs ``hca_compress_production_fp32`` with the same
    arguments and converts its FP32 result to ``torch.bfloat16``.
    """
    compressed_fp32 = hca_compress_production_fp32(
        kv_proj_out,
        gate_proj_out,
        position_bias,
        kv_norm_weight,
        m,
    )
    return compressed_fp32.to(torch.bfloat16)
def hca_compress_production_fp32(
    kv_proj_out: torch.Tensor,
    gate_proj_out: torch.Tensor,
    position_bias: Optional[torch.Tensor],
    kv_norm_weight: Optional[torch.Tensor],
    m: int = 128,
) -> torch.Tensor:
    """HCA compress: softmax + weighted sum + kv_norm. Returns FP32.

    Splits the ``T`` input rows into ``T // m`` blocks of ``m`` tokens and
    runs the ``hca_compress_reduce`` extension, which fills one output row
    per block. A partial trailing block (``T % m`` tokens) gets no output row.

    Args:
        kv_proj_out: ``(T, hd)`` KV projection output (FP32 per the module
            pipeline description). Passed to the kernel as a contiguous tensor.
        gate_proj_out: gate projection output; must have the same shape as
            ``kv_proj_out``.
        position_bias: optional position bias, cast to FP32. When ``None``,
            an empty FP32 tensor is passed to the kernel instead.
        kv_norm_weight: optional kv_norm weight, cast to FP32. When ``None``,
            an empty FP32 tensor is passed to the kernel instead.
        m: tokens per compressed block; must be positive.

    Returns:
        ``(T // m, hd)`` FP32 tensor on ``kv_proj_out.device``. When
        ``T < m`` this is ``(0, hd)`` and the extension is never loaded.

    Raises:
        ValueError: if ``m`` is not positive, ``kv_proj_out`` is not 2-D, or
            ``gate_proj_out`` has a different shape. These are checked up
            front so a bad call fails in Python instead of handing mismatched
            buffers to the kernel.
    """
    if m <= 0:
        raise ValueError(f"m must be a positive block size, got {m}")
    if kv_proj_out.dim() != 2:
        raise ValueError(
            f"kv_proj_out must be 2-D (T, hd), got shape {tuple(kv_proj_out.shape)}"
        )
    if gate_proj_out.shape != kv_proj_out.shape:
        raise ValueError(
            "gate_proj_out must have the same shape as kv_proj_out, got "
            f"{tuple(gate_proj_out.shape)} vs {tuple(kv_proj_out.shape)}"
        )

    T = kv_proj_out.shape[0]
    hd = kv_proj_out.shape[1]
    n_blocks = T // m
    device = kv_proj_out.device
    if n_blocks == 0:
        # Not even one full block: nothing to compress, so skip the JIT load.
        return torch.zeros(0, hd, dtype=torch.float32, device=device)

    mod = _get_kernel()
    # Optional inputs are cast to FP32. An empty FP32 tensor is passed for
    # any that are None.
    pos_bias_f32 = (
        position_bias.float()
        if position_bias is not None
        else torch.empty(0, dtype=torch.float32, device=device)
    )
    norm_f32 = (
        kv_norm_weight.float()
        if kv_norm_weight is not None
        else torch.empty(0, dtype=torch.float32, device=device)
    )
    compressed = torch.zeros(n_blocks, hd, dtype=torch.float32, device=device)
    mod.hca_compress_reduce(
        kv_proj_out.contiguous(),
        gate_proj_out.contiguous(),
        pos_bias_f32.contiguous(),
        norm_f32.contiguous(),
        compressed,
        m, n_blocks,
    )
    return compressed
# ===========================================================================
# KV-1/KV-2: NVFP4 output — compress kernel + two-kernel quantizer, no BF16 intermediate
#
# Architecture:
# 1. CUDA compress kernel (compressor_reduce.cu) → FP32 compressed output
# 2. CUDA amax_gsa_fp32 → per-row gsa (GPU-only, no CPU sync)
# 3. CUDA quantize_nvfp4_from_fp32 → NVFP4 triple (fp4 + sf + gsa)
#
# Steps 2-3 use the same two-kernel quantization pattern that works elsewhere
# in the pipeline (quantize_nvfp4_gpu_fused). The previous single-kernel fused
# approach had shared-memory corruption bugs, so quantization is kept as two
# separate kernels.
#
# Storage: NVFP4 (E2M1 data + E4M3 block scales + FP32 global scale)
# Read path: dequant_nvfp4 / dequant_nvfp4_selective → BF16 for FMHA
# ===========================================================================
def _quantize_fp32_to_nvfp4(
    compressed_fp32: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize FP32 compressed output → NVFP4. Two-kernel, GPU-only.

    Uses the same pattern as quantize_nvfp4_gpu_fused (amax_gsa +
    quantize_from_buffer), but with FP32 input instead of BF16. The global
    scales from the first kernel are passed to the second as a tensor, so
    there is no BF16 intermediate and no host round-trip in this function.

    Args:
        compressed_fp32: FP32 compressor output, presumably ``(n_rows, hd)``
            as produced by the ``*_compress_production_fp32`` functions
            (TODO confirm against kv_quantize.cu). It is made contiguous
            once, and that same buffer feeds both kernels.

    Returns:
        ``(fp4_data, block_scales, global_scales)`` — the NVFP4 triple, in
        whatever layout the ``kv_quantize`` extension returns.
    """
    from dsv4.kernels.cuda.loader import get_cuda_module

    mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
    # Take the contiguous buffer once. Calling .contiguous() per kernel would
    # copy a non-contiguous input twice.
    src = compressed_fp32.contiguous()
    # 6.0 is the largest E2M1 (FP4) magnitude and 448.0 the largest E4M3
    # magnitude. Presumably the kernel maps each row's amax onto this combined
    # range — confirm in kv_quantize.cu.
    fp4_max_times_e4m3_max = 6.0 * 448.0
    # Kernel 1: per-row gsa from the FP32 input (result stays on the GPU).
    gsa = mod.compute_amax_gsa_fp32(src, fp4_max_times_e4m3_max)
    # Kernel 2: quantize FP32 → NVFP4 using the GPU-resident gsa buffer.
    fp4, sf = mod.quantize_nvfp4_from_fp32(src, gsa)
    return fp4, sf, gsa
def csa_compress_production_nvfp4(
    kv_proj_out: torch.Tensor,
    gate_proj_out: torch.Tensor,
    position_bias: Optional[torch.Tensor],
    kv_norm_weight: Optional[torch.Tensor],
    m: int = 4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CSA compress → NVFP4. No BF16 intermediate.

    KV-1: Production path. Compressed KV stored as NVFP4.
    Pipeline: compress (FP32) → amax_gsa (GPU) → quantize (GPU) → NVFP4 triple.

    Arguments are forwarded unchanged to ``csa_compress_production_fp32``.

    Returns:
        ``(fp4_data, block_scales, global_scales)`` — NVFP4 triple. When no
        complete block of ``m`` tokens exists, three zero-row tensors
        (``(0, hd // 2)`` float4_e2m1fn_x2, ``(0, hd // 16)`` float8_e4m3fn,
        ``(0,)`` float32) are returned without launching the quantizer.
    """
    # Step 1: Compress → FP32 (same kernel as the BF16 path).
    compressed_fp32 = csa_compress_production_fp32(
        kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m)
    if compressed_fp32.shape[0] == 0:
        # Take hd and device from the FP32 result so the empty triple matches
        # what the compress stage produced. Zero-element tensors need no
        # initialization, so torch.empty is used. torch.zeros would dispatch a
        # fill, which may not be implemented for the packed float4 dtype.
        device = compressed_fp32.device
        hd = compressed_fp32.shape[1]
        return (
            torch.empty(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=device),
            torch.empty(0, hd // 16, dtype=torch.float8_e4m3fn, device=device),
            torch.empty(0, dtype=torch.float32, device=device),
        )
    # Steps 2-3: FP32 → NVFP4 (two kernels).
    return _quantize_fp32_to_nvfp4(compressed_fp32)
def hca_compress_production_nvfp4(
    kv_proj_out: torch.Tensor,
    gate_proj_out: torch.Tensor,
    position_bias: Optional[torch.Tensor],
    kv_norm_weight: Optional[torch.Tensor],
    m: int = 128,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """HCA compress → NVFP4. No BF16 intermediate.

    KV-2: Production path. Compressed KV stored as NVFP4.
    Pipeline: compress (FP32) → amax_gsa (GPU) → quantize (GPU) → NVFP4 triple.

    Arguments are forwarded unchanged to ``hca_compress_production_fp32``.

    Returns:
        ``(fp4_data, block_scales, global_scales)`` — NVFP4 triple. When no
        complete block of ``m`` tokens exists, three zero-row tensors
        (``(0, hd // 2)`` float4_e2m1fn_x2, ``(0, hd // 16)`` float8_e4m3fn,
        ``(0,)`` float32) are returned without launching the quantizer.
    """
    # Step 1: Compress → FP32.
    compressed_fp32 = hca_compress_production_fp32(
        kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m)
    if compressed_fp32.shape[0] == 0:
        # Take hd and device from the FP32 result so the empty triple matches
        # what the compress stage produced. Zero-element tensors need no
        # initialization, so torch.empty is used. torch.zeros would dispatch a
        # fill, which may not be implemented for the packed float4 dtype.
        device = compressed_fp32.device
        hd = compressed_fp32.shape[1]
        return (
            torch.empty(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=device),
            torch.empty(0, hd // 16, dtype=torch.float8_e4m3fn, device=device),
            torch.empty(0, dtype=torch.float32, device=device),
        )
    # Steps 2-3: FP32 → NVFP4 (two kernels).
    return _quantize_fp32_to_nvfp4(compressed_fp32)