Bug #2 fix: warmup_compilation and warmup_fused_swiglu_compilation now use valid FP4 data by quantizing random BF16 through quantize_to_nvfp4. Random uint8 bytes interpreted as FP4 bit patterns cause cudaErrorIllegalInstruction in Blackwell MMA hardware. Re-enabled warmup calls in runner.py.

Bug #1 kernel: sparse_topk_metadata.cu with:
- build_c128a_topk_metadata: position-based compressed KV slot lookup via block table for C128A (compress_ratio=128) decode tokens
- compute_c4a_global_topk: local topk index -> global slot ID mapping via block table for C4A (compress_ratio=4) decode tokens
- Both tested: correct block table lookups, proper padding

Bug #3 kernel: C4A uses compute_c4a_global_topk (same .cu file)
- Replaces vLLM Triton kernel with our own CUDA kernel

Deleted stale STATUS.md, FUSED_EPILOGUE_STATUS.md, FUSED_EPILOGUE_PLAN.md, CURRENT_BUG.md
90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
"""
|
|
Sparse topk metadata kernels for DeepSeek-V4 Blackwell decode attention.

These are our own CUDA kernels; they do not depend on FlashMLA or on the
Triton kernels from vLLM.

C128A: position-based compressed KV slot lookup via block table.
C4A: local topk index to global slot ID mapping via block table.
|
|
"""
|
|
|
|
import os
|
|
import torch
|
|
from typing import Optional
|
|
|
|
# Module-level cache for the loaded CUDA extension. It stays None until
# _get_kernel_module() performs the first load and stores the result here.
_kernel_module = None
|
|
|
|
def _get_kernel_module():
    """Return the ``sparse_topk_metadata`` CUDA extension, loading it on first use.

    The loaded module is stored in the module-level ``_kernel_module`` so that
    later calls hand back the same object without invoking the loader again.
    """
    global _kernel_module

    if _kernel_module is None:
        # Imported here rather than at file level, so the extension tooling is
        # only pulled in when a load is actually required.
        from torch.utils import cpp_extension

        # The .cu source is looked up in a "kernels" directory beside this file.
        cuda_source = os.path.join(
            os.path.dirname(__file__), "kernels", "sparse_topk_metadata.cu"
        )
        _kernel_module = cpp_extension.load(
            name="sparse_topk_metadata",
            sources=[cuda_source],
            # Code is generated for the compute_100a / sm_100a target only.
            extra_cuda_cflags=[
                "-O3",
                "--generate-code=arch=compute_100a,code=[sm_100a]",
            ],
            verbose=False,
        )

    return _kernel_module
|
|
|
|
|
|
def build_c128a_topk_metadata(
    positions: torch.Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    slot_mapping: torch.Tensor,
    global_decode_buffer: torch.Tensor,
    decode_lens_buffer: torch.Tensor,
    prefill_buffer: torch.Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple:
    """Produce the C128A topk metadata for decode and prefill tokens.

    Decode tokens have their compressed KV positions translated into global
    slot IDs by looking them up in the block table. The result is the tuple
    (global_decode, decode_lens, prefill_local).

    This is a thin wrapper: every argument is forwarded unchanged, in
    declaration order, to the ``build_c128a_topk_metadata`` entry point of
    the CUDA extension.
    NOTE(review): tensor shapes/dtypes and whether the ``*_buffer`` arguments
    are written in place are decided by the kernel, which is not visible
    here -- confirm against sparse_topk_metadata.cu.
    """
    extension = _get_kernel_module()
    # Collected positionally; the order must match the extension's signature.
    kernel_args = (
        positions,
        compress_ratio,
        num_decode_tokens,
        token_to_req,
        block_table,
        block_size,
        slot_mapping,
        global_decode_buffer,
        decode_lens_buffer,
        prefill_buffer,
        max_compressed_tokens,
    )
    return extension.build_c128a_topk_metadata(*kernel_args)
|
|
|
|
|
|
def compute_c4a_global_topk(
    local_topk: torch.Tensor,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    is_valid_token: torch.Tensor,
) -> tuple:
    """Translate local C4A topk indices into global KV cache slot IDs.

    Each token's local compressed indices (as produced by the indexer) are
    resolved to global slot IDs through the block table. The result is the
    tuple (global_topk_indices, topk_lens).

    A ``torch.bool`` validity mask is converted to ``torch.int32`` before it
    is handed to the extension; a mask of any other dtype is passed through
    untouched.
    """
    extension = _get_kernel_module()

    # Only a bool mask is cast; presumably the kernel reads 32-bit integers
    # for this argument -- verify against sparse_topk_metadata.cu.
    valid_mask = (
        is_valid_token.to(torch.int32)
        if is_valid_token.dtype == torch.bool
        else is_valid_token
    )

    return extension.compute_c4a_global_topk(
        local_topk, token_to_req, block_table, block_size, valid_mask
    )
|