- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4/layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
"""
|
|
Sparse topk metadata kernels for DeepSeek-V4 Blackwell decode attention.
|
|
|
|
Own kernels — no FlashMLA, no Triton from vLLM.
|
|
|
|
C128A: position-based compressed KV slot lookup via block table.
|
|
C4A: local topk index to global slot ID mapping via block table.
|
|
"""
|
|
|
|
import os
|
|
import torch
|
|
from typing import Optional
|
|
|
|
# Process-wide cache for the JIT-built CUDA extension module. It stays None
# until _get_kernel_module() builds the extension on first use, after which
# the loaded module is stored here and reused.
_kernel_module = None
|
|
|
|
def _get_kernel_module():
    """Return the ``sparse_topk_metadata`` CUDA extension, building it once.

    On the first call the extension is JIT-compiled and loaded through
    ``torch.utils.cpp_extension.load``; the resulting module is stored in the
    module-level ``_kernel_module`` cache. Every later call returns the
    cached module without invoking the build again.
    """
    global _kernel_module
    if _kernel_module is None:
        # Deferred import: the extension build tooling is only pulled in when
        # a build is actually needed.
        from torch.utils import cpp_extension

        # The CUDA source is expected in a "kernels" directory next to this file.
        source_path = os.path.join(
            os.path.dirname(__file__), "kernels", "sparse_topk_metadata.cu"
        )
        # Optimised build; requests code generation for compute_100a / sm_100a.
        nvcc_flags = ["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"]
        _kernel_module = cpp_extension.load(
            name="sparse_topk_metadata",
            sources=[source_path],
            extra_cuda_cflags=nvcc_flags,
            verbose=False,
        )
    return _kernel_module
|
|
|
|
|
|
def build_c128a_topk_metadata(
    positions: torch.Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    slot_mapping: torch.Tensor,
    global_decode_buffer: torch.Tensor,
    decode_lens_buffer: torch.Tensor,
    prefill_buffer: torch.Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple:
    """Build C128A topk metadata for decode and prefill tokens.

    Thin Python entry point: every argument is forwarded positionally and
    unchanged, in signature order, to ``build_c128a_topk_metadata`` of the
    lazily loaded CUDA extension, and that call's result is returned as-is.

    For decode tokens the kernel maps compressed KV positions to global slot
    IDs via block table lookup.

    NOTE(review): the three ``*_buffer`` tensors look like caller-allocated
    outputs that the kernel fills; tensor shapes and dtypes are defined by
    the CUDA source, not by this wrapper. Confirm against the kernel.

    Returns:
        ``(global_decode, decode_lens, prefill_local)``.
    """
    kernel = _get_kernel_module().build_c128a_topk_metadata
    return kernel(
        positions,
        compress_ratio,
        num_decode_tokens,
        token_to_req,
        block_table,
        block_size,
        slot_mapping,
        global_decode_buffer,
        decode_lens_buffer,
        prefill_buffer,
        max_compressed_tokens,
    )
|
|
|
|
|
|
def compute_c4a_global_topk(
    local_topk: torch.Tensor,
    token_to_req: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    is_valid_token: torch.Tensor,
) -> tuple:
    """Map local C4A topk indices to global KV cache slots.

    For each token, the local compressed indices (from the indexer) are
    translated to global slot IDs via block table lookup. The work is done by
    ``compute_c4a_global_topk`` of the lazily loaded CUDA extension; this
    wrapper only normalises the validity mask before forwarding.

    A ``torch.bool`` ``is_valid_token`` is converted to ``torch.int32``; a
    mask of any other dtype is passed through untouched.

    Returns:
        ``(global_topk_indices, topk_lens)``.
    """
    ext = _get_kernel_module()
    # Boolean masks are handed to the kernel as int32 values.
    valid_mask = (
        is_valid_token.to(torch.int32)
        if is_valid_token.dtype == torch.bool
        else is_valid_token
    )
    return ext.compute_c4a_global_topk(
        local_topk, token_to_req, block_table, block_size, valid_mask
    )
|