nvfp4-megamoe-kernel/cutedsl/sparse_topk_metadata.py
biondizzle 67d5e26080 Fix warmup compilation + add sparse topk metadata kernels
Bug #2 fix: warmup_compilation and warmup_fused_swiglu_compilation now
use valid FP4 data by quantizing random BF16 through quantize_to_nvfp4.
Random uint8 bytes reinterpreted as FP4 bit patterns cause cudaErrorIllegalInstruction
on Blackwell MMA hardware. Re-enabled the warmup calls in runner.py.

Bug #1 kernel: sparse_topk_metadata.cu with:
  - build_c128a_topk_metadata: position-based compressed KV slot lookup
    via block table for C128A (compress_ratio=128) decode tokens
  - compute_c4a_global_topk: local topk index -> global slot ID mapping
    via block table for C4A (compress_ratio=4) decode tokens
  - Both tested: correct block table lookups, proper padding

Bug #3 kernel: C4A uses compute_c4a_global_topk (same .cu file)
  - Replaces vLLM Triton kernel with our own CUDA kernel

Deleted stale STATUS.md, FUSED_EPILOGUE_STATUS.md, FUSED_EPILOGUE_PLAN.md, CURRENT_BUG.md
2026-05-20 06:43:43 +00:00

"""
Sparse topk metadata kernels for DeepSeek-V4 Blackwell decode attention.
Own kernels — no FlashMLA, no Triton from vLLM.
C128A: position-based compressed KV slot lookup via block table.
C4A: local topk index to global slot ID mapping via block table.
"""
import os

import torch

# Cached handle to the JIT-compiled CUDA extension (built on first use).
_kernel_module = None


def _get_kernel_module():
    """Lazy-load the CUDA extension."""
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module
    # Deferred import: the build tooling is only needed on the first call.
    from torch.utils.cpp_extension import load

    kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
    _kernel_module = load(
        name="sparse_topk_metadata",
        sources=[os.path.join(kernel_dir, "sparse_topk_metadata.cu")],
        # Compile for Blackwell (sm_100a) only.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module


def build_c128a_topk_metadata(
    positions: torch.Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    slot_mapping: torch.Tensor,
    global_decode_buffer: torch.Tensor,
    decode_lens_buffer: torch.Tensor,
    prefill_buffer: torch.Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple:
    """Build C128A topk metadata for decode and prefill tokens.

    For decode tokens: maps compressed KV positions to global slot IDs
    via block table lookup. Returns (global_decode, decode_lens, prefill_local).
    """
    mod = _get_kernel_module()
    return mod.build_c128a_topk_metadata(
        positions,
        compress_ratio,
        num_decode_tokens,
        token_to_req,
        block_table,
        block_size,
        slot_mapping,
        global_decode_buffer,
        decode_lens_buffer,
        prefill_buffer,
        max_compressed_tokens,
    )


def compute_c4a_global_topk(
    local_topk: torch.Tensor,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    is_valid_token: torch.Tensor,
) -> tuple:
    """Map local C4A topk indices to global KV cache slots.

    For each token, takes local compressed indices (from the indexer)
    and maps them to global slot IDs via block table lookup.
    Returns (global_topk_indices, topk_lens).
    """
    mod = _get_kernel_module()
    # Convert bool masks to int32 before handing them to the extension.
    if is_valid_token.dtype == torch.bool:
        is_valid_token = is_valid_token.to(torch.int32)
    return mod.compute_c4a_global_topk(
        local_topk,
        token_to_req,
        block_table,
        block_size,
        is_valid_token,
    )
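
The block table lookups described in the docstrings above can be illustrated with a small pure-Python reference. This is only a sketch under stated assumptions, not the contents of sparse_topk_metadata.cu: the helper names (`slot_for`, `c4a_global_topk_reference`, `c128a_decode_slots_reference`), the `-1` padding value, the convention that a negative local index marks padding, the `block_table[req][logical_block]` layout, and the `(position + 1) // compress_ratio` count of compressed entries are all hypothetical choices made for illustration.

```python
PAD = -1  # ASSUMPTION: padding sentinel; the real kernel's value is not shown in this file


def slot_for(block_table, req, idx, block_size):
    """Map a request-local compressed index to a global slot ID.

    ASSUMPTION: block_table[req][logical_block] holds the physical block number.
    """
    physical_block = block_table[req][idx // block_size]
    return physical_block * block_size + idx % block_size


def c4a_global_topk_reference(local_topk, token_to_req, block_table,
                              block_size, is_valid_token):
    """Hypothetical stand-in for the C4A local -> global topk mapping."""
    global_topk, topk_lens = [], []
    for tok, row in enumerate(local_topk):
        req = token_to_req[tok]
        out, count = [], 0
        for idx in row:
            # ASSUMPTION: invalid tokens and negative indices become padding.
            if is_valid_token[tok] and idx >= 0:
                out.append(slot_for(block_table, req, idx, block_size))
                count += 1
            else:
                out.append(PAD)
        global_topk.append(out)
        topk_lens.append(count)
    return global_topk, topk_lens


def c128a_decode_slots_reference(positions, token_to_req, block_table,
                                 block_size, compress_ratio,
                                 max_compressed_tokens):
    """Hypothetical stand-in for the C128A position-based decode lookup."""
    rows, lens = [], []
    for tok, pos in enumerate(positions):
        # ASSUMPTION: a token at position `pos` sees (pos + 1) // ratio compressed entries.
        n = min((pos + 1) // compress_ratio, max_compressed_tokens)
        req = token_to_req[tok]
        row = [slot_for(block_table, req, i, block_size) for i in range(n)]
        rows.append(row + [PAD] * (max_compressed_tokens - n))
        lens.append(n)
    return rows, lens


# Two requests, two physical blocks each, four slots per block.
block_table = [[7, 2], [5, 9]]
c4a_out = c4a_global_topk_reference(
    [[0, 5, -1], [3, 4, 6]], [0, 1], block_table, 4, [1, 1])
c128a_out = c128a_decode_slots_reference(
    [300, 127], [0, 1], block_table, 4, 128, 4)
```

A reference like this is mainly useful as a test oracle: running the CUDA kernels and the Python loops on the same small inputs makes a wrong block table index or a missing pad easy to spot, provided the assumed conventions are first aligned with what the kernels actually do.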