Files
nvfp4-megamoe-kernel/dsv4/kernels/cuda/sparse_topk_metadata.cu
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

216 lines
9.1 KiB
Plaintext

#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <cstdint>
#include <limits>
// ============================================================================
// C128A topk metadata kernel
// ============================================================================
// For C128A (compress_ratio=128) decode tokens:
// - position -> num_compressed = (position + 1) / compress_ratio
// - For each compressed KV slot [0, num_compressed):
// block_index = i / block_size
// block_offset = i % block_size
// global_slot = block_table[req_idx, block_index] * block_size + block_offset
// - Output: global_slot IDs in out_indices, count in out_lens
// - Invalid tokens (slot_mapping < 0) get length 0
//
// For prefill tokens:
// - Output: local indices [0, 1, ..., num_compressed-1, -1, -1, ...]
// ============================================================================
// Launch contract:
//   - One block per token: blockIdx.x is used as the token index without a
//     bounds check, so gridDim.x must equal the number of tokens.
//   - Blocks are expected to be 1-D. The threads of a block split the per-slot
//     loop with a stride of blockDim.x, so every block width (including 1)
//     produces the same output. No shared memory is used.
//   - Tokens [0, num_decode_tokens) take the decode path; every other token
//     takes the prefill path and writes row (token_idx - num_decode_tokens) of
//     prefill_local_ptr.
//   - Rows are addressed as row * stride + column, i.e. the innermost dimension
//     of every 2-D buffer is assumed to have unit stride.
//   - compress_ratio and block_size must be non-zero (they are divisors).
__global__ void build_c128a_topk_metadata_kernel(
    // Decode outputs
    int32_t* __restrict__ global_decode_ptr,       // [num_decode_tokens, max_compressed]
    int64_t global_decode_stride,                  // stride in elements
    int32_t* __restrict__ decode_lens_ptr,         // [num_decode_tokens]
    // Prefill output
    int32_t* __restrict__ prefill_local_ptr,       // [num_prefill_tokens, max_compressed]
    int64_t prefill_local_stride,
    // Inputs
    const int64_t* __restrict__ positions_ptr,     // [num_tokens]
    int32_t compress_ratio,
    int32_t max_compressed_tokens,
    int32_t num_decode_tokens,
    const int32_t* __restrict__ token_to_req_ptr,  // [num_tokens]
    const int32_t* __restrict__ block_table_ptr,   // [num_reqs, max_blocks_per_seq]
    int64_t block_table_stride,                    // stride in elements
    int32_t block_size,
    const int64_t* __restrict__ slot_mapping_ptr   // [num_tokens]
) {
    const int token_idx = blockIdx.x;
    const int32_t lane = static_cast<int32_t>(threadIdx.x);
    const int32_t lane_step = static_cast<int32_t>(blockDim.x);

    // Clamp in 64-bit *before* narrowing so a very large quotient cannot wrap
    // around in the int32 cast. A non-positive quotient means "no compressed
    // slots yet" and is folded to 0 (every output column is then -1).
    int64_t wide_count = (positions_ptr[token_idx] + 1) / compress_ratio;
    if (wide_count > max_compressed_tokens) wide_count = max_compressed_tokens;
    if (wide_count < 0) wide_count = 0;
    const int32_t num_compressed = static_cast<int32_t>(wide_count);

    if (token_idx < num_decode_tokens) {
        // Decode: block-table lookup -> global slot ids + count.
        // The slot ids are written even when the token is invalid (only the
        // length is zeroed), so token_to_req / block_table must be readable
        // for such tokens as well.
        const bool token_is_valid = slot_mapping_ptr[token_idx] >= 0;
        const int64_t table_row =
            static_cast<int64_t>(token_to_req_ptr[token_idx]) * block_table_stride;
        const int64_t out_row = static_cast<int64_t>(token_idx) * global_decode_stride;

        for (int32_t i = lane; i < max_compressed_tokens; i += lane_step) {
            int32_t slot_id = -1;  // padding for columns past num_compressed
            if (i < num_compressed) {
                const int32_t block_number = block_table_ptr[table_row + i / block_size];
                slot_id = block_number * block_size + i % block_size;
            }
            global_decode_ptr[out_row + i] = slot_id;
        }

        // Exactly num_compressed columns receive a slot id, so the length needs
        // no per-element counting; a single thread publishes it.
        if (lane == 0) {
            decode_lens_ptr[token_idx] = token_is_valid ? num_compressed : 0;
        }
    } else {
        // Prefill: write local indices [0, 1, ..., n-1, -1, ...].
        const int64_t out_row =
            static_cast<int64_t>(token_idx - num_decode_tokens) * prefill_local_stride;
        for (int32_t i = lane; i < max_compressed_tokens; i += lane_step) {
            prefill_local_ptr[out_row + i] = (i < num_compressed) ? i : -1;
        }
    }
}
// ============================================================================
// C4A topk metadata kernel
// ============================================================================
// For C4A (compress_ratio=4) decode tokens:
// - topk_indices: local compressed indices from the indexer
// - Map each local index to a global KV cache slot via block table lookup
// - Count valid entries (local_idx >= 0)
// - Invalid tokens get length 0
// ============================================================================
// Launch contract:
//   - One block per token: blockIdx.x is used as the token index without a
//     bounds check, so gridDim.x must equal the number of tokens.
//   - Blocks are expected to be 1-D. The threads of a block split the top-k
//     loop with a stride of blockDim.x and combine their valid-entry counts
//     through one shared-memory counter (4 bytes of static shared memory), so
//     every block width (including 1) produces the same output.
//   - Rows are addressed as row * stride + column, i.e. the innermost dimension
//     of every 2-D buffer is assumed to have unit stride.
//   - block_size must be non-zero (it is a divisor).
__global__ void compute_c4a_global_topk_kernel(
    // Outputs
    int32_t* __restrict__ global_topk_ptr,          // [num_tokens, topk_dim]
    int64_t global_topk_stride,                     // stride in elements
    int32_t* __restrict__ topk_lens_ptr,            // [num_tokens]
    // Inputs
    const int32_t* __restrict__ local_topk_ptr,     // [num_tokens, topk_dim]
    int64_t local_topk_stride,                      // stride in elements
    int32_t topk_dim,
    const int32_t* __restrict__ token_to_req_ptr,   // [num_tokens]
    const int32_t* __restrict__ block_table_ptr,    // [num_reqs, max_blocks_per_seq]
    int64_t block_table_stride,
    int32_t block_size,
    const int32_t* __restrict__ is_valid_token_ptr  // [num_tokens] boolean as int32
) {
    // Block-wide tally of entries with local_idx >= 0.
    __shared__ int32_t block_valid_count;

    const int token_idx = blockIdx.x;
    const int32_t lane = static_cast<int32_t>(threadIdx.x);
    const int32_t lane_step = static_cast<int32_t>(blockDim.x);

    if (lane == 0) {
        block_valid_count = 0;
    }
    __syncthreads();  // counter must be zeroed before any thread adds to it

    const int64_t in_row = static_cast<int64_t>(token_idx) * local_topk_stride;
    const int64_t out_row = static_cast<int64_t>(token_idx) * global_topk_stride;
    const int64_t table_row =
        static_cast<int64_t>(token_to_req_ptr[token_idx]) * block_table_stride;

    int32_t my_valid = 0;
    for (int32_t i = lane; i < topk_dim; i += lane_step) {
        const int32_t local_idx = local_topk_ptr[in_row + i];
        int32_t slot_id = -1;  // negative local indices are passed through as -1
        if (local_idx >= 0) {
            const int32_t block_number = block_table_ptr[table_row + local_idx / block_size];
            slot_id = block_number * block_size + local_idx % block_size;
            ++my_valid;
        }
        global_topk_ptr[out_row + i] = slot_id;
    }

    // One atomic per thread at most; threads with nothing to add skip it.
    if (my_valid > 0) {
        atomicAdd(&block_valid_count, my_valid);
    }
    __syncthreads();  // all partial counts must land before the total is read

    // Slot ids are written for every token; only the length is zeroed when the
    // token is flagged invalid.
    if (lane == 0) {
        topk_lens_ptr[token_idx] = is_valid_token_ptr[token_idx] ? block_valid_count : 0;
    }
}
// ============================================================================
// Python bindings
// ============================================================================
// Host launcher for build_c128a_topk_metadata_kernel.
//
// Launches one block per token (one thread per block) on the current PyTorch
// CUDA stream of the device that holds `positions`. The kernel writes in place
// into the caller-provided buffers; the returned tensors are views of them:
//   get<0>: global_decode_buffer[:num_decode_tokens]
//   get<1>: decode_lens_buffer[:num_decode_tokens]
//   get<2>: prefill_buffer[:num_tokens - num_decode_tokens]
// With zero tokens the views are returned without launching anything.
//
// Throws (via TORCH_CHECK / data_ptr<T>) when:
//   - positions is not a 1-D CUDA tensor, or num_decode_tokens is outside
//     [0, num_tokens];
//   - compress_ratio <= 0, max_compressed_tokens < 0, or (with decode tokens)
//     block_size <= 0 -- the kernel divides by compress_ratio and block_size;
//   - a tensor the kernel will dereference is on another device, has the wrong
//     rank, has a non-unit innermost stride, or is too small for what will be
//     read from / written to it;
//   - a tensor has a dtype other than the one named in its parameter comment;
//   - the kernel launch itself is rejected by the CUDA runtime.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> build_c128a_topk_metadata_cuda(
    torch::Tensor positions,             // [num_tokens] int64
    int32_t compress_ratio,
    int32_t num_decode_tokens,
    torch::Tensor token_to_req,          // [num_tokens] int32
    torch::Tensor block_table,           // [num_reqs, max_blocks] int32
    int32_t block_size,
    torch::Tensor slot_mapping,          // [num_tokens] int64
    torch::Tensor global_decode_buffer,  // [max_tokens, max_compressed] int32
    torch::Tensor decode_lens_buffer,    // [max_tokens] int32
    torch::Tensor prefill_buffer,        // [max_tokens, max_compressed] int32
    int32_t max_compressed_tokens
) {
    TORCH_CHECK(positions.dim() == 1, "positions must be 1-D, got ", positions.dim(), " dims");
    // The token count becomes gridDim.x, so it has to fit the int32 the kernel
    // indexes with.
    TORCH_CHECK(positions.size(0) <= std::numeric_limits<int32_t>::max(),
                "too many tokens for a single launch: ", positions.size(0));
    const int32_t num_tokens = static_cast<int32_t>(positions.size(0));
    TORCH_CHECK(num_decode_tokens >= 0 && num_decode_tokens <= num_tokens,
                "num_decode_tokens must be in [0, ", num_tokens, "], got ", num_decode_tokens);
    const int32_t num_prefill_tokens = num_tokens - num_decode_tokens;

    auto global_decode = global_decode_buffer.narrow(0, 0, num_decode_tokens);
    auto decode_lens = decode_lens_buffer.narrow(0, 0, num_decode_tokens);
    auto prefill_local = prefill_buffer.narrow(0, 0, num_prefill_tokens);
    if (num_tokens == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to do.
        return std::make_tuple(global_decode, decode_lens, prefill_local);
    }

    TORCH_CHECK(positions.is_cuda(), "positions must be a CUDA tensor");
    TORCH_CHECK(compress_ratio > 0, "compress_ratio must be positive, got ", compress_ratio);
    TORCH_CHECK(max_compressed_tokens >= 0,
                "max_compressed_tokens must be non-negative, got ", max_compressed_tokens);
    const auto device = positions.device();

    // The kernel indexes 2-D buffers as row * stride(0) + column.
    const auto check_rows = [&](const torch::Tensor& t, const char* name, int64_t min_cols) {
        TORCH_CHECK(t.device() == device, name, " must be on ", device);
        TORCH_CHECK(t.dim() == 2, name, " must be 2-D, got ", t.dim(), " dims");
        TORCH_CHECK(t.size(1) <= 1 || t.stride(1) == 1,
                    name, " must have a unit-stride last dimension");
        TORCH_CHECK(t.size(1) >= min_cols,
                    name, " needs at least ", min_cols, " columns, got ", t.size(1));
    };
    const auto check_vector = [&](const torch::Tensor& t, const char* name, int64_t min_len) {
        TORCH_CHECK(t.device() == device, name, " must be on ", device);
        TORCH_CHECK(t.dim() == 1, name, " must be 1-D, got ", t.dim(), " dims");
        TORCH_CHECK(t.size(0) >= min_len,
                    name, " needs at least ", min_len, " elements, got ", t.size(0));
    };

    // Only tensors the kernel will actually dereference are validated, so a
    // decode-only or prefill-only call is not rejected over an unused buffer.
    if (num_decode_tokens > 0) {
        TORCH_CHECK(block_size > 0, "block_size must be positive, got ", block_size);
        check_vector(token_to_req, "token_to_req", num_decode_tokens);
        check_vector(slot_mapping, "slot_mapping", num_decode_tokens);
        check_vector(decode_lens_buffer, "decode_lens_buffer", num_decode_tokens);
        TORCH_CHECK(decode_lens_buffer.size(0) <= 1 || decode_lens_buffer.stride(0) == 1,
                    "decode_lens_buffer must have unit stride");
        check_rows(block_table, "block_table", 0);
        check_rows(global_decode_buffer, "global_decode_buffer", max_compressed_tokens);
        // Inputs are read with unit stride; copy only if they are not already dense.
        token_to_req = token_to_req.contiguous();
        slot_mapping = slot_mapping.contiguous();
    }
    if (num_prefill_tokens > 0) {
        check_rows(prefill_buffer, "prefill_buffer", max_compressed_tokens);
    }
    positions = positions.contiguous();

    // Launch on the tensors' device and on PyTorch's current stream so the
    // kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    build_c128a_topk_metadata_kernel<<<num_tokens, 1, 0, stream>>>(
        global_decode_buffer.data_ptr<int32_t>(),
        global_decode_buffer.stride(0),
        decode_lens_buffer.data_ptr<int32_t>(),
        prefill_buffer.data_ptr<int32_t>(),
        prefill_buffer.stride(0),
        positions.data_ptr<int64_t>(),
        compress_ratio,
        max_compressed_tokens,
        num_decode_tokens,
        token_to_req.data_ptr<int32_t>(),
        block_table.data_ptr<int32_t>(),
        block_table.stride(0),
        block_size,
        slot_mapping.data_ptr<int64_t>()
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface bad launch configurations immediately

    return std::make_tuple(global_decode, decode_lens, prefill_local);
}
// Host launcher for compute_c4a_global_topk_kernel.
//
// Allocates the outputs and launches one block per token (one thread per
// block) on the current PyTorch CUDA stream of the device holding local_topk.
// Returns:
//   get<0>: global_topk [num_tokens, topk_dim] int32, contiguous
//   get<1>: topk_lens   [num_tokens] int32
// With zero tokens the (empty) outputs are returned without launching, because
// a zero-sized grid is an invalid launch configuration.
//
// is_valid_token may be an int32 tensor or a bool tensor; a bool tensor is
// converted to int32 before the launch. A local_topk whose last dimension is
// not unit-stride is copied to a contiguous tensor first.
//
// Throws (via TORCH_CHECK / data_ptr<T>) when local_topk is not a 2-D int32
// CUDA tensor, block_size <= 0, a tensor is on another device, has the wrong
// rank or dtype, or is too short, or the CUDA runtime rejects the launch.
std::tuple<torch::Tensor, torch::Tensor> compute_c4a_global_topk_cuda(
    torch::Tensor local_topk,      // [num_tokens, topk_dim] int32
    torch::Tensor token_to_req,    // [num_tokens] int32
    torch::Tensor block_table,     // [num_reqs, max_blocks] int32
    int32_t block_size,
    torch::Tensor is_valid_token   // [num_tokens] bool (stored as int32)
) {
    TORCH_CHECK(local_topk.dim() == 2, "local_topk must be 2-D, got ", local_topk.dim(), " dims");
    TORCH_CHECK(local_topk.scalar_type() == torch::kInt32, "local_topk must be int32");
    TORCH_CHECK(local_topk.size(0) <= std::numeric_limits<int32_t>::max() &&
                    local_topk.size(1) <= std::numeric_limits<int32_t>::max(),
                "local_topk is too large for a single launch");
    const int32_t num_tokens = static_cast<int32_t>(local_topk.size(0));
    const int32_t topk_dim = static_cast<int32_t>(local_topk.size(1));

    // Explicit shapes give contiguous outputs regardless of local_topk's layout.
    auto global_topk = torch::empty({num_tokens, topk_dim}, local_topk.options());
    auto topk_lens = torch::empty({num_tokens}, local_topk.options().dtype(torch::kInt32));
    if (num_tokens == 0) {
        return std::make_tuple(global_topk, topk_lens);
    }

    TORCH_CHECK(local_topk.is_cuda(), "local_topk must be a CUDA tensor");
    TORCH_CHECK(block_size > 0, "block_size must be positive, got ", block_size);
    const auto device = local_topk.device();

    TORCH_CHECK(token_to_req.device() == device, "token_to_req must be on ", device);
    TORCH_CHECK(token_to_req.dim() == 1 && token_to_req.size(0) >= num_tokens,
                "token_to_req must be 1-D with at least ", num_tokens, " elements");
    TORCH_CHECK(is_valid_token.device() == device, "is_valid_token must be on ", device);
    TORCH_CHECK(is_valid_token.dim() == 1 && is_valid_token.size(0) >= num_tokens,
                "is_valid_token must be 1-D with at least ", num_tokens, " elements");
    TORCH_CHECK(block_table.device() == device, "block_table must be on ", device);
    TORCH_CHECK(block_table.dim() == 2, "block_table must be 2-D, got ", block_table.dim(), " dims");
    TORCH_CHECK(block_table.size(1) <= 1 || block_table.stride(1) == 1,
                "block_table must have a unit-stride last dimension");

    // Switch to the tensors' device before any conversion/copy kernels run.
    const c10::cuda::CUDAGuard device_guard(device);

    // The kernel reads rows as row * stride(0) + column and 1-D inputs with
    // unit stride; normalise anything that does not already match.
    if (topk_dim > 1 && local_topk.stride(1) != 1) {
        local_topk = local_topk.contiguous();
    }
    token_to_req = token_to_req.contiguous();
    if (is_valid_token.scalar_type() == torch::kBool) {
        is_valid_token = is_valid_token.to(torch::kInt32);
    }
    is_valid_token = is_valid_token.contiguous();

    // PyTorch's current stream keeps the launch ordered with the conversions
    // above and with the caller's surrounding work.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    compute_c4a_global_topk_kernel<<<num_tokens, 1, 0, stream>>>(
        global_topk.data_ptr<int32_t>(),
        global_topk.stride(0),
        topk_lens.data_ptr<int32_t>(),
        local_topk.data_ptr<int32_t>(),
        local_topk.stride(0),
        topk_dim,
        token_to_req.data_ptr<int32_t>(),
        block_table.data_ptr<int32_t>(),
        block_table.stride(0),
        block_size,
        is_valid_token.data_ptr<int32_t>()
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface bad launch configurations immediately

    return std::make_tuple(global_topk, topk_lens);
}
// Python module definition. The module name is taken from the
// TORCH_EXTENSION_NAME macro; two functions are exported, each bound directly
// to its host launcher.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Docstrings shown by Python's help() for the exported functions.
    static constexpr const char* kC128aDoc =
        "Build C128A topk metadata (global slot IDs + lengths)";
    static constexpr const char* kC4aDoc =
        "Compute C4A global topk indices and lengths from local indices";

    // C128A: positions + block table -> decode slot ids, decode lengths,
    // prefill local indices.
    m.def("build_c128a_topk_metadata", &build_c128a_topk_metadata_cuda, kC128aDoc);

    // C4A: local top-k indices -> global slot ids + per-token lengths.
    m.def("compute_c4a_global_topk", &compute_c4a_global_topk_cuda, kC4aDoc);
}