133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
"""Production compressor: NVFP4 GEMM projections + CUDA softmax/reduce kernel.
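
Pipeline:
    1. NVFP4 GEMM: hidden_states @ kv_proj → kv (T, kv_dim)
    2. NVFP4 GEMM: hidden_states @ gate_proj → gate (T, kv_dim)
    3. CUDA kernel: token-level softmax(gate) * kv → compressed entries
    4. CUDA kernel: kv_norm (RMSNorm without an affine scale, then multiplied by the norm weight)

Steps 1-2 run upstream: the functions in this module receive the FP32 projection
outputs as arguments and perform steps 3-4 in the ``compressor_reduce`` CUDA extension.

No PyTorch softmax and no reference fallback: everything runs on the GPU.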
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import torch
|
|
from typing import Optional
|
|
|
|
# Cache for the compiled ``compressor_reduce`` extension; populated lazily by _get_kernel().
_kernel_module = None


def _get_kernel():
    """Return the ``compressor_reduce`` CUDA extension, JIT-building it on first use.

    The extension is compiled from ``../cuda/compressor_reduce.cu`` (relative to
    this file) with ``torch.utils.cpp_extension.load``, using ``-O3`` and code
    generation for ``sm_100a`` only. The loaded module is stored in
    ``_kernel_module`` so later calls return it without rebuilding. If the build
    raises, the cache stays unset and the next call retries the build.
    """
    global _kernel_module
    if _kernel_module is None:
        # Deferred import: only pay for cpp_extension when the kernel is needed.
        from torch.utils import cpp_extension

        cuda_src_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
        nvcc_flags = ["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"]
        _kernel_module = cpp_extension.load(
            name="compressor_reduce",
            sources=[os.path.join(cuda_src_dir, "compressor_reduce.cu")],
            extra_cuda_cflags=nvcc_flags,
            verbose=False,
        )
    return _kernel_module
|
|
|
|
|
|
def csa_compress_production(
    kv_proj_out: torch.Tensor,  # (T, 2*hd) FP32 — output of NVFP4 GEMM
    gate_proj_out: torch.Tensor,  # (T, 2*hd) FP32 — output of NVFP4 GEMM
    position_bias: Optional[torch.Tensor],  # (m, 2*hd) BF16 or None
    kv_norm_weight: Optional[torch.Tensor],  # (hd) BF16 or None
    m: int = 4,
) -> torch.Tensor:
    """CSA compress: softmax + weighted sum + kv_norm.

    Validates the inputs, converts them to contiguous FP32 on the device of
    ``kv_proj_out``, and passes them to the ``csa_compress_reduce`` CUDA kernel,
    which fills an FP32 ``(T // m, hd)`` buffer that is returned as BF16.

    Args:
        kv_proj_out: FP32 projection output, (T, 2*hd), Ca in first hd cols,
            Cb in second.
        gate_proj_out: FP32 projection output, (T, 2*hd), Ga in first hd cols,
            Gb in second. Must have the same shape as ``kv_proj_out``.
        position_bias: (m, 2*hd) BF16 position bias, or None. It is cast to
            FP32 and moved to ``kv_proj_out``'s device before the launch.
        kv_norm_weight: (hd) BF16 norm weight, or None. Converted the same way
            as ``position_bias``.
        m: compression ratio (4 for CSA). Must be positive. The output has
            ``T // m`` rows; when ``T < m`` an empty result is returned without
            loading or launching the kernel.

    Returns:
        compressed: (T // m, hd) BF16 tensor on ``kv_proj_out``'s device.

    Raises:
        ValueError: if ``m <= 0``; if the projections are not 2-D tensors of
            identical shape; if their width is odd; or if ``position_bias`` /
            ``kv_norm_weight`` do not hold ``m * 2 * hd`` / ``hd`` elements.
    """
    # Validate up front, before the (possibly JIT-compiling) kernel load. The
    # kernel receives raw buffers and may not check sizes itself (its source is
    # not visible here), so a mismatch would otherwise risk out-of-bounds reads.
    if m <= 0:
        raise ValueError(f"m must be a positive integer, got {m}")
    if kv_proj_out.dim() != 2 or kv_proj_out.shape != gate_proj_out.shape:
        raise ValueError(
            "kv_proj_out and gate_proj_out must be 2-D tensors of the same shape, "
            f"got {tuple(kv_proj_out.shape)} and {tuple(gate_proj_out.shape)}"
        )
    T, width = kv_proj_out.shape
    if width % 2:
        # The width is split into two hd-wide halves (Ca|Cb, Ga|Gb).
        raise ValueError(f"projection width must be even (2*hd), got {width}")
    hd = width // 2
    if position_bias is not None and position_bias.numel() != m * 2 * hd:
        raise ValueError(
            f"position_bias must have m * 2 * hd = {m * 2 * hd} elements, "
            f"got {position_bias.numel()}"
        )
    if kv_norm_weight is not None and kv_norm_weight.numel() != hd:
        raise ValueError(
            f"kv_norm_weight must have hd = {hd} elements, got {kv_norm_weight.numel()}"
        )

    device = kv_proj_out.device
    n_blocks = T // m
    if n_blocks == 0:
        return torch.zeros(0, hd, dtype=torch.bfloat16, device=device)

    mod = _get_kernel()

    # The kernel is fed FP32 buffers on one device. A missing optional input is
    # passed as an empty tensor (presumably the kernel's "not provided" signal).
    kv_f32 = kv_proj_out.float().contiguous()  # no copy when already FP32
    gate_f32 = gate_proj_out.float().contiguous()
    pos_bias_f32 = (
        torch.empty(0, dtype=torch.float32, device=device)
        if position_bias is None
        else position_bias.to(device=device, dtype=torch.float32).contiguous()
    )
    norm_f32 = (
        torch.empty(0, dtype=torch.float32, device=device)
        if kv_norm_weight is None
        else kv_norm_weight.to(device=device, dtype=torch.float32).contiguous()
    )

    compressed = torch.zeros(n_blocks, hd, dtype=torch.float32, device=device)
    mod.csa_compress_reduce(
        kv_f32,
        gate_f32,
        pos_bias_f32,
        norm_f32,
        compressed,
        m, n_blocks,
    )
    return compressed.bfloat16()
|
|
|
|
|
|
def hca_compress_production(
    kv_proj_out: torch.Tensor,  # (T, hd) FP32
    gate_proj_out: torch.Tensor,  # (T, hd) FP32
    position_bias: Optional[torch.Tensor],  # (m, hd) BF16 or None
    kv_norm_weight: Optional[torch.Tensor],  # (hd) BF16 or None
    m: int = 128,
) -> torch.Tensor:
    """HCA compress: softmax + weighted sum + kv_norm.

    Validates the inputs, converts them to contiguous FP32 on the device of
    ``kv_proj_out``, and passes them to the ``hca_compress_reduce`` CUDA kernel,
    which fills an FP32 ``(T // m, hd)`` buffer that is returned as BF16.

    Args:
        kv_proj_out: FP32 projection output, (T, hd).
        gate_proj_out: FP32 projection output, (T, hd). Must have the same
            shape as ``kv_proj_out``.
        position_bias: (m, hd) BF16 position bias, or None. It is cast to FP32
            and moved to ``kv_proj_out``'s device before the launch.
        kv_norm_weight: (hd) BF16 norm weight, or None. Converted the same way
            as ``position_bias``.
        m: compression ratio (128 for HCA). Must be positive. The output has
            ``T // m`` rows; when ``T < m`` an empty result is returned without
            loading or launching the kernel.

    Returns:
        compressed: (T // m, hd) BF16 tensor on ``kv_proj_out``'s device.

    Raises:
        ValueError: if ``m <= 0``; if the projections are not 2-D tensors of
            identical shape; or if ``position_bias`` / ``kv_norm_weight`` do
            not hold ``m * hd`` / ``hd`` elements.
    """
    # Validate up front, before the (possibly JIT-compiling) kernel load. The
    # kernel receives raw buffers and may not check sizes itself (its source is
    # not visible here), so a mismatch would otherwise risk out-of-bounds reads.
    if m <= 0:
        raise ValueError(f"m must be a positive integer, got {m}")
    if kv_proj_out.dim() != 2 or kv_proj_out.shape != gate_proj_out.shape:
        raise ValueError(
            "kv_proj_out and gate_proj_out must be 2-D tensors of the same shape, "
            f"got {tuple(kv_proj_out.shape)} and {tuple(gate_proj_out.shape)}"
        )
    T, hd = kv_proj_out.shape
    if position_bias is not None and position_bias.numel() != m * hd:
        raise ValueError(
            f"position_bias must have m * hd = {m * hd} elements, "
            f"got {position_bias.numel()}"
        )
    if kv_norm_weight is not None and kv_norm_weight.numel() != hd:
        raise ValueError(
            f"kv_norm_weight must have hd = {hd} elements, got {kv_norm_weight.numel()}"
        )

    device = kv_proj_out.device
    n_blocks = T // m
    if n_blocks == 0:
        return torch.zeros(0, hd, dtype=torch.bfloat16, device=device)

    mod = _get_kernel()

    # The kernel is fed FP32 buffers on one device. A missing optional input is
    # passed as an empty tensor (presumably the kernel's "not provided" signal).
    kv_f32 = kv_proj_out.float().contiguous()  # no copy when already FP32
    gate_f32 = gate_proj_out.float().contiguous()
    pos_bias_f32 = (
        torch.empty(0, dtype=torch.float32, device=device)
        if position_bias is None
        else position_bias.to(device=device, dtype=torch.float32).contiguous()
    )
    norm_f32 = (
        torch.empty(0, dtype=torch.float32, device=device)
        if kv_norm_weight is None
        else kv_norm_weight.to(device=device, dtype=torch.float32).contiguous()
    )

    compressed = torch.zeros(n_blocks, hd, dtype=torch.float32, device=device)
    mod.hca_compress_reduce(
        kv_f32,
        gate_f32,
        pos_bias_f32,
        norm_f32,
        compressed,
        m, n_blocks,
    )
    return compressed.bfloat16()
|