# nvfp4-megamoe-kernel/dsv4/kernels/compressor/production_compress.py

"""Production compressor: NVFP4 GEMM projections + CUDA softmax/reduce kernel.

Pipeline:
  1. NVFP4 GEMM: hidden_states @ kv_proj → kv (T, kv_dim)
  2. NVFP4 GEMM: hidden_states @ gate_proj → gate (T, kv_dim)
  3. CUDA kernel: token-level softmax(gate) * kv → compressed entries
  4. CUDA kernel: kv_norm (unweighted RMSNorm + weight)

Steps 1–2 are performed by the caller. The functions in this module take the
FP32 GEMM outputs and run steps 3–4 in the ``compressor_reduce`` CUDA
extension. There is no PyTorch softmax and no reference fallback: all
compression math runs on the GPU.
"""
from __future__ import annotations
import os
import torch
from typing import Optional
# Cached handle to the JIT-built ``compressor_reduce`` extension (None until first use).
_kernel_module = None


def _get_kernel():
    """Return the ``compressor_reduce`` CUDA extension, compiling it on first call.

    The loaded module is stored in the module-level ``_kernel_module`` so that
    later calls return it directly instead of invoking ``load`` again.

    Returns:
        The extension module exposing ``csa_compress_reduce`` and
        ``hca_compress_reduce``.
    """
    global _kernel_module
    if _kernel_module is None:
        # Imported lazily so that merely importing this file does not touch
        # the extension-build machinery.
        from torch.utils.cpp_extension import load

        cuda_src_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
        source_file = os.path.join(cuda_src_dir, "compressor_reduce.cu")
        # -O3 plus code generation for compute_100a / sm_100a only; presumably
        # the extension is intended to run on that architecture alone.
        nvcc_flags = ["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"]
        _kernel_module = load(
            name="compressor_reduce",
            sources=[source_file],
            extra_cuda_cflags=nvcc_flags,
            verbose=False,
        )
    return _kernel_module
def _validate_projections(
    kv_proj_out: torch.Tensor,
    gate_proj_out: torch.Tensor,
    m: int,
) -> None:
    """Raise ``ValueError`` unless the projection pair and ratio ``m`` are usable.

    Checks that ``m`` is positive, that ``kv_proj_out`` is 2-D, and that
    ``gate_proj_out`` matches it in shape and device.
    """
    if m <= 0:
        raise ValueError(f"compression ratio m must be positive, got {m}")
    if kv_proj_out.dim() != 2:
        raise ValueError(
            f"kv_proj_out must be 2-D (T, width), got shape {tuple(kv_proj_out.shape)}"
        )
    if gate_proj_out.shape != kv_proj_out.shape:
        raise ValueError(
            "gate_proj_out must have the same shape as kv_proj_out, got "
            f"{tuple(gate_proj_out.shape)} vs {tuple(kv_proj_out.shape)}"
        )
    if gate_proj_out.device != kv_proj_out.device:
        raise ValueError(
            "gate_proj_out must be on the same device as kv_proj_out, got "
            f"{gate_proj_out.device} vs {kv_proj_out.device}"
        )


def _optional_param_f32(
    param: Optional[torch.Tensor],
    expected_numel: int,
    name: str,
    device: torch.device,
) -> torch.Tensor:
    """Return ``param`` as a contiguous FP32 tensor on ``device``, or an empty sentinel.

    ``None`` maps to an empty FP32 tensor. That is what this wrapper has always
    passed for a missing parameter; presumably the kernel treats
    ``numel() == 0`` as "not provided" (TODO confirm in compressor_reduce.cu).

    Raises:
        ValueError: if ``param`` is given but does not hold exactly
            ``expected_numel`` elements. A wrong size could otherwise make the
            kernel read out of bounds.
    """
    if param is None:
        return torch.empty(0, dtype=torch.float32, device=device)
    if param.numel() != expected_numel:
        raise ValueError(
            f"{name} must have {expected_numel} elements, "
            f"got shape {tuple(param.shape)}"
        )
    # ``.to`` returns the tensor unchanged when dtype and device already match.
    # Moving it here avoids handing a host pointer to the CUDA kernel.
    return param.to(device=device, dtype=torch.float32).contiguous()


def _compress_production(
    kernel_name: str,
    kv_proj_out: torch.Tensor,
    gate_proj_out: torch.Tensor,
    position_bias: Optional[torch.Tensor],
    kv_norm_weight: Optional[torch.Tensor],
    m: int,
    hd: int,
) -> torch.Tensor:
    """Shared driver for the CSA/HCA wrappers: run ``kernel_name`` and return BF16.

    Args:
        kernel_name: attribute of the extension to call
            (``csa_compress_reduce`` or ``hca_compress_reduce``).
        kv_proj_out, gate_proj_out: already-validated projection outputs
            of shape (T, width).
        position_bias: optional bias holding ``m * width`` elements.
        kv_norm_weight: optional norm weight holding ``hd`` elements.
        m: compression ratio (tokens per compressed entry).
        hd: width of each compressed entry.

    Returns:
        (T // m, hd) BF16 tensor on ``kv_proj_out.device``.
    """
    device = kv_proj_out.device
    width = kv_proj_out.shape[1]
    pos_bias_f32 = _optional_param_f32(position_bias, m * width, "position_bias", device)
    norm_f32 = _optional_param_f32(kv_norm_weight, hd, "kv_norm_weight", device)

    # Only complete blocks of m tokens are compressed; the trailing T % m
    # tokens produce no output row.
    n_blocks = kv_proj_out.shape[0] // m
    if n_blocks == 0:
        return torch.zeros(0, hd, dtype=torch.bfloat16, device=device)

    kernel = getattr(_get_kernel(), kernel_name)
    # Kept zero-initialised (not torch.empty): the kernel may accumulate into
    # the output or skip elements -- TODO confirm before switching to empty.
    compressed = torch.zeros(n_blocks, hd, dtype=torch.float32, device=device)
    kernel(
        # All kernel buffers are FP32; the casts are no-ops for FP32 inputs.
        kv_proj_out.to(dtype=torch.float32).contiguous(),
        gate_proj_out.to(dtype=torch.float32).contiguous(),
        pos_bias_f32,
        norm_f32,
        compressed,
        m, n_blocks,
    )
    return compressed.bfloat16()


def csa_compress_production(
    kv_proj_out: torch.Tensor,  # (T, 2*hd) FP32 — output of NVFP4 GEMM
    gate_proj_out: torch.Tensor,  # (T, 2*hd) FP32 — output of NVFP4 GEMM
    position_bias: Optional[torch.Tensor],  # (m, 2*hd) BF16 or None
    kv_norm_weight: Optional[torch.Tensor],  # (hd) BF16 or None
    m: int = 4,
) -> torch.Tensor:
    """CSA compress: softmax + weighted sum + kv_norm.

    Args:
        kv_proj_out: FP32 projection output, (T, 2*hd), Ca in first hd cols, Cb in second
        gate_proj_out: FP32 projection output, (T, 2*hd), Ga in first hd cols, Gb in second
        position_bias: (m, 2*hd) BF16 position bias, or None
        kv_norm_weight: (hd) BF16 norm weight, or None
        m: compression ratio (4 for CSA)

    Returns:
        compressed: (T // m, hd) BF16. Trailing T % m tokens are not compressed.

    Raises:
        ValueError: if ``m <= 0``, ``kv_proj_out`` is not 2-D or has an odd
            width, ``gate_proj_out`` differs in shape/device, or an optional
            parameter has the wrong number of elements.
    """
    _validate_projections(kv_proj_out, gate_proj_out, m)
    width = kv_proj_out.shape[1]
    if width % 2:
        # The two hd-wide halves (Ca/Cb, Ga/Gb) require an even width.
        raise ValueError(f"CSA projection width must be even (2*hd), got {width}")
    return _compress_production(
        "csa_compress_reduce",
        kv_proj_out, gate_proj_out, position_bias, kv_norm_weight,
        m, width // 2,
    )


def hca_compress_production(
    kv_proj_out: torch.Tensor,  # (T, hd) FP32
    gate_proj_out: torch.Tensor,  # (T, hd) FP32
    position_bias: Optional[torch.Tensor],  # (m, hd) BF16 or None
    kv_norm_weight: Optional[torch.Tensor],  # (hd) BF16 or None
    m: int = 128,
) -> torch.Tensor:
    """HCA compress: softmax + weighted sum + kv_norm.

    Args:
        kv_proj_out: FP32 projection output, (T, hd)
        gate_proj_out: FP32 projection output, (T, hd)
        position_bias: (m, hd) BF16 position bias, or None
        kv_norm_weight: (hd) BF16 norm weight, or None
        m: compression ratio (128 for HCA)

    Returns:
        compressed: (T // m, hd) BF16. Trailing T % m tokens are not compressed.

    Raises:
        ValueError: if ``m <= 0``, ``kv_proj_out`` is not 2-D,
            ``gate_proj_out`` differs in shape/device, or an optional parameter
            has the wrong number of elements.
    """
    _validate_projections(kv_proj_out, gate_proj_out, m)
    return _compress_production(
        "hca_compress_reduce",
        kv_proj_out, gate_proj_out, position_bias, kv_norm_weight,
        m, kv_proj_out.shape[1],
    )