Files
nvfp4-megamoe-kernel/dsv4/ops/topk.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

90 lines
2.5 KiB
Python

"""
Sparse topk metadata kernels for DeepSeek-V4 Blackwell decode attention.
Own kernels — no FlashMLA, no Triton from vLLM.
C128A: position-based compressed KV slot lookup via block table.
C4A: local topk index to global slot ID mapping via block table.
"""
import os
import torch
from typing import Optional
# Cached handle to the JIT-built CUDA extension; stays None until a load succeeds.
_kernel_module = None


def _get_kernel_module():
    """Lazily JIT-compile and cache the sparse-topk-metadata CUDA extension.

    The extension is built from ``sparse_topk_metadata.cu`` with
    ``torch.utils.cpp_extension.load`` on first use and stored in the
    module-level ``_kernel_module``; later calls return the cached module.
    If the build raises, nothing is cached, so the next call retries.

    The ``.cu`` source is looked up in the legacy ``ops/kernels`` directory
    first (original behavior), then in the package-level ``kernels`` tree
    (``kernels/cuda``, ``kernels/decode``, ``kernels``), where kernel sources
    were relocated by the package restructure.

    Returns:
        The loaded extension module.

    Raises:
        FileNotFoundError: if ``sparse_topk_metadata.cu`` exists in none of
            the searched directories.
    """
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module

    source_name = "sparse_topk_metadata.cu"
    # abspath guards against a relative __file__ combined with a later chdir.
    ops_dir = os.path.dirname(os.path.abspath(__file__))
    package_kernels_dir = os.path.join(os.path.dirname(ops_dir), "kernels")
    # Legacy location first so existing layouts keep resolving identically.
    search_dirs = [
        os.path.join(ops_dir, "kernels"),
        os.path.join(package_kernels_dir, "cuda"),
        os.path.join(package_kernels_dir, "decode"),
        package_kernels_dir,
    ]
    source_path = next(
        (
            os.path.join(directory, source_name)
            for directory in search_dirs
            if os.path.isfile(os.path.join(directory, source_name))
        ),
        None,
    )
    if source_path is None:
        raise FileNotFoundError(
            f"{source_name} not found; searched: {', '.join(search_dirs)}"
        )

    # Imported lazily (as before) so the cpp_extension machinery is only
    # pulled in when a build is actually needed.
    from torch.utils.cpp_extension import load

    _kernel_module = load(
        name="sparse_topk_metadata",
        sources=[source_path],
        # Arch-specific sm_100a (Blackwell) target only.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module
def build_c128a_topk_metadata(
    positions: torch.Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    slot_mapping: torch.Tensor,
    global_decode_buffer: torch.Tensor,
    decode_lens_buffer: torch.Tensor,
    prefill_buffer: torch.Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple:
    """Compute C128A sparse-topk metadata through the CUDA extension.

    This is a thin forwarder. Every argument is handed positionally, in
    signature order, to the extension's ``build_c128a_topk_metadata``.
    For decode tokens, compressed KV positions are mapped to global slot
    IDs via ``block_table``.

    Returns:
        ``(global_decode, decode_lens, prefill_local)`` as produced by the
        extension. ``global_decode_buffer``, ``decode_lens_buffer`` and
        ``prefill_buffer`` are presumably caller-preallocated outputs the
        kernel writes into — NOTE(review): confirm in sparse_topk_metadata.cu.
    """
    ext = _get_kernel_module()
    # Order must match the C++ binding's positional signature exactly.
    kernel_args = (
        positions,
        compress_ratio,
        num_decode_tokens,
        token_to_req,
        block_table,
        block_size,
        slot_mapping,
        global_decode_buffer,
        decode_lens_buffer,
        prefill_buffer,
        max_compressed_tokens,
    )
    return ext.build_c128a_topk_metadata(*kernel_args)
def compute_c4a_global_topk(
    local_topk: torch.Tensor,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    is_valid_token: torch.Tensor,
) -> tuple:
    """Translate per-token local C4A topk indices into global KV-cache slots.

    Each token's local compressed indices (from the indexer) are resolved
    to global slot IDs via ``block_table`` by the CUDA extension.

    Returns:
        ``(global_topk_indices, topk_lens)`` as produced by the extension.
    """
    ext = _get_kernel_module()
    # A bool validity mask is widened to int32 before the call (presumably
    # the dtype the kernel expects — confirm in sparse_topk_metadata.cu);
    # any other dtype is forwarded untouched.
    valid_mask = (
        is_valid_token.to(torch.int32)
        if is_valid_token.dtype == torch.bool
        else is_valid_token
    )
    return ext.compute_c4a_global_topk(
        local_topk, token_to_req, block_table, block_size, valid_mask
    )