// Portions of this file are adapted from SGLang PR:
// https://github.com/sgl-project/sglang/pull/11194
// and
// https://github.com/sgl-project/sglang/pull/17747
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k
constexpr int kThreadsPerBlock = 1024; // Threads per block
// Shared memory budget
#if defined(USE_ROCM)
constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each radix pass needs at most ~TopK candidates in the threshold bin,
// so 4K candidate slots per ping-pong buffer (2 buffers = 8K entries = 32KB)
// is sufficient.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
#endif
struct FastTopKParams {
const float* __restrict__ input; // [batch, seq_len] Logits
const int32_t* __restrict__ row_starts; // [batch] Offset into each row
// (optional)
int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices
int32_t* __restrict__ lengths; // [batch] Sequence lengths per row
int64_t input_stride; // Stride between rows
};
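// The helper below maps a float to a uint32_t whose unsigned ordering matches
// the float's numeric ordering (the standard radix-sort key transform):
// floats with the sign bit clear get the sign bit set so they land above all
// negatives, and negative floats are bitwise-inverted so more-negative values
// map lower. Worked example: +1.0f (0x3F800000) -> 0xBF800000, while
// -1.0f (0xBF800000) -> ~bits = 0x407FFFFF, so the transformed key of +1.0f
// compares greater than that of -1.0f, as required.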
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
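// Coarse 8-bit radix key used by pass 0: round the float to FP16, apply the
// same order-preserving transform to the 16-bit pattern, and keep only the
// high byte, giving 256 coarsely ordered bins for the first selection pass.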
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
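// Trivial path for short rows: when seq_len <= TopK every element is in the
// top-k, so the block simply writes indices 0..seq_len-1 and pads the rest of
// the output with -1.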
__device__ void naive_topk_cuda(const float* __restrict__ logits,
int32_t* __restrict__ output_indices,
int32_t seq_len) {
const int thread_id = threadIdx.x;
for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
output_indices[i] = (i < seq_len) ? i : -1;
}
}
// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which is in turn an optimized top-k kernel adapted from a TileLang kernel
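// Radix-select overview (describing the passes below):
//   Pass 0  : histogram a coarse 8-bit key (FP16 high byte) over the whole
//             row, suffix-sum it, and pick the bin where the count of
//             elements in that bin and above still exceeds TopK while the
//             count strictly above does not. Elements strictly above that bin
//             are emitted immediately; ties are buffered in shared memory.
//   Pass 1-4: re-key only the buffered ties with successive bytes of the
//             order-preserving FP32 key (bits [31:24] down to [7:0]),
//             shrinking the candidate set until exactly TopK indices have
//             been written. The histogram for each next pass is built fused
//             into the scan of the current pass.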
__device__ void fast_topk_cuda_tl(
const float* __restrict__ logits, // Input logits [seq_len]
int* __restrict__ output_indices, // Output top-k indices [TopK]
int logits_offset, // Starting offset in logits array
int seq_len) // Number of valid logits to process
{
constexpr int RADIX = 256;
constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
alignas(128) __shared__ int shared_output_count;
alignas(128) __shared__ int shared_threshold_bin;
alignas(128) __shared__ int shared_buffered_count[2];
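// Dynamic shared memory, viewed as two ping-pong buffers of
// MAX_BUFFERED_ITEMS candidate indices each (kSmem bytes total): one holds
// the ties being consumed in the current pass, the other collects ties for
// the next pass.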
extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
const int thread_id = threadIdx.x;
int remaining_k = TopK;
// Pass 0: Build coarse 8-bit histogram using FP16 high bits
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
::atomicAdd(&shared_histogram[0][bin], 1);
}
__syncthreads();
// Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong
// buffers
auto compute_cumulative_sum = [&]() {
static_assert(1 << 8 == RADIX,
"Radix must be 256 for 8 unrolled iterations");
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (C10_LIKELY(thread_id < RADIX)) {
const int stride = 1 << i;
const int src_buffer = i & 1;
const int dst_buffer = src_buffer ^ 1;
int value = shared_histogram[src_buffer][thread_id];
if (thread_id < RADIX - stride) {
value += shared_histogram[src_buffer][thread_id + stride];
}
shared_histogram[dst_buffer][thread_id] = value;
}
__syncthreads();
}
};
compute_cumulative_sum();
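// After the suffix sum, shared_histogram[0][b] holds the number of elements
// whose key byte is >= b. The threshold bin t is therefore the unique bin
// where this count is still > remaining_k while the count for bin t+1 has
// dropped to <= remaining_k: everything in bins above t is certainly in the
// top-k, and bin t itself holds the ties that need refinement.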
// Find threshold bin where cumsum crosses remaining_k
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[0] = 0;
shared_output_count = 0;
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Early exit if threshold bin perfectly matches remaining_k
if (remaining_k == 0) {
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const int bin = convert_to_uint8(logits[idx + logits_offset]);
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
return;
}
// Prepare for refinement passes: Process threshold bin
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
// Scan all elements and:
// 1. Write indices > threshold_bin to output
// 2. Buffer indices == threshold_bin for refinement
// 3. Build histogram for next refinement pass (fused optimization)
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const float logit_value = logits[idx + logits_offset];
const int bin = convert_to_uint8(logit_value);
if (bin > threshold_bin) {
// in top-k, write to output
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
// Candidate for top-k, needs refinement
const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[0][buffer_pos] = idx;
// Fused: Build histogram for next pass
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int next_bin = (fp32_bits >> 24) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
__syncthreads();
// ============================================================================
// Passes 1-4: Refine using 8-bit passes over FP32 bits
// ============================================================================
// FP32 bits [31:0] are split into 4 bytes, processed MSB-first:
//   Pass 1: bits [31:24], Pass 2: bits [23:16],
//   Pass 3: bits [15:8],  Pass 4: bits [7:0]
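// Loop invariant at the top of each pass: shared_histogram[0] already holds
// the per-byte counts of the buffered candidates (built fused during the
// previous scan), shared_buffered_count[src_buffer] is the number of
// candidates (clamped to the buffer capacity before use), and remaining_k is
// the number of output slots still unfilled.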
#pragma unroll 4
for (int pass = 0; pass < 4; ++pass) {
__shared__ int shared_final_k; // For final pass: remaining slots to fill
const int src_buffer = pass % 2;
const int dst_buffer = src_buffer ^ 1;
// Clamp buffered count to prevent overflow
const int raw_buffered = shared_buffered_count[src_buffer];
const int num_buffered =
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
compute_cumulative_sum();
// Find threshold bin for this pass
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[dst_buffer] = 0;
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Bit offset for this pass: 24, 16, 8, 0
const int bit_offset = 24 - pass * 8;
// Early exit if threshold bin perfectly matches
if (remaining_k == 0) {
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const uint32_t fp32_bits =
convert_to_uint32_v2(logits[idx + logits_offset]);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
break;
}
// Continue refinement
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const float logit_value = logits[idx + logits_offset];
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
// Definitely in top-k
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
if (pass == 3) {
// Final pass (bits [7:0]): no more refinement possible.
// Fill the last shared_final_k output slots (positions TopK - shared_final_k
// .. TopK - 1) with ties in the lowest byte; order among ties is arbitrary.
const int slot = ::atomicAdd(&shared_final_k, -1);
if (slot > 0) {
output_indices[TopK - slot] = idx;
}
} else {
// Buffer for next pass and build next histogram
const int buffer_pos =
::atomicAdd(&shared_buffered_count[dst_buffer], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[dst_buffer][buffer_pos] = idx;
// Fused: Build histogram for next pass
const int next_bit_offset = bit_offset - 8;
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
}
__syncthreads();
}
}
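// One thread block per batch row. Rows with seq_len <= TopK take the trivial
// path; longer rows run the multi-pass radix select above. Output indices are
// zero-based within the scanned window: row_starts (when provided) only
// shifts where reads begin inside the row.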
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
const FastTopKParams params) {
const auto& [input, row_starts, indices, lengths, input_stride] = params;
const uint64_t batch_idx = blockIdx.x;
const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
const int seq_len = lengths[batch_idx];
int* output_indices = indices + batch_idx * TopK;
const float* logits = input + batch_idx * input_stride;
if (seq_len <= TopK) {
// Shortcut: All elements are in top-k
return naive_topk_cuda(logits, output_indices, seq_len);
} else {
return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
}
}
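// Validate tensor shapes and pack raw device pointers into FastTopKParams
// (data_ptr<T> also enforces the expected dtypes). row_starts is optional;
// when it is absent the kernel reads each row starting at offset 0.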
FastTopKParams get_params(
const at::Tensor& score, const at::Tensor& lengths,
std::optional<at::Tensor> row_starts_opt = std::nullopt,
std::optional<at::Tensor> indices_opt = std::nullopt) {
const int64_t batch_size = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
"score must be 2D with contiguous rows");
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
lengths.size(0) == batch_size,
"lengths must be 1D contiguous with size matching batch");
const int32_t* row_starts_ptr = nullptr;
if (row_starts_opt.has_value()) {
const auto& row_starts = *row_starts_opt;
TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
"row_starts must be 1D with size matching batch");
row_starts_ptr = row_starts.data_ptr<int32_t>();
}
int32_t* indices_ptr = nullptr;
if (indices_opt.has_value()) {
const auto& indices = *indices_opt;
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
indices.size(0) == batch_size && indices.size(1) == TopK,
"indices must be 2D contiguous [batch, TopK]");
indices_ptr = indices.data_ptr<int32_t>();
}
return FastTopKParams{
.input = score.data_ptr<float>(),
.row_starts = row_starts_ptr,
.indices = indices_ptr,
.lengths = lengths.data_ptr<int32_t>(),
.input_stride = score.stride(0),
};
}
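// Requesting more dynamic shared memory than the default per-block limit
// requires an explicit opt-in via cudaFuncAttributeMaxDynamicSharedMemorySize.
// The static initializer below performs that call exactly once per process
// for the given kernel/size pair.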
template <auto* kernel_func, size_t smem_bytes>
void setup_kernel_smem_once() {
static const cudaError_t result = []() -> cudaError_t {
#ifdef USE_ROCM
auto func_ptr = reinterpret_cast<const void*>(kernel_func);
#else
auto func_ptr = kernel_func;
#endif
return cudaFuncSetAttribute(
func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}();
TORCH_CHECK(
result == cudaSuccess,
"Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
}
} // namespace vllm
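// Host-side entry point. A minimal usage sketch from C++ (tensor names and
// the batch/seq_len placeholders are illustrative only, not part of this
// file):
//
//   auto opts_f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   auto opts_i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   auto logits   = torch::randn({batch, seq_len}, opts_f32);
//   auto seq_lens = torch::full({batch}, seq_len, opts_i32);
//   auto indices  = torch::empty({batch, 2048}, opts_i32);  // TopK == 2048
//   large_context_topk(logits, indices, seq_lens);  // row_starts omitted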
void large_context_topk(
const torch::Tensor& logits, torch::Tensor& indices,
const torch::Tensor& seq_lens,
std::optional<torch::Tensor> row_starts = std::nullopt) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
if (row_starts.has_value()) {
TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
}
const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
const int64_t batch_size = logits.size(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const dim3 grid(static_cast<uint32_t>(batch_size));
const dim3 block(vllm::kThreadsPerBlock);
vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
const cudaError_t result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess,
"large_context_topk kernel failed: ", cudaGetErrorString(result));
}