// Portions of this file are adapted from SGLang PR:
// https://github.com/sgl-project/sglang/pull/11194
// and
// https://github.com/sgl-project/sglang/pull/17747

#include "cuda_compat.h"
#include "dispatch_utils.h"

// NOTE(review): the angle-bracket include targets were stripped from the
// original paste; the headers below are reconstructed from the symbols used
// (TORCH_CHECK / at::Tensor / at::cuda::getCurrentCUDAStream, __half
// intrinsics) — confirm against the upstream file.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

#ifndef USE_ROCM
  #include <cuda_fp16.h>
#else
  #include <hip/hip_fp16.h>
#endif

namespace vllm {

constexpr int TopK = 2048;              // DeepSeek V3 sparse attention top-k
constexpr int kThreadsPerBlock = 1024;  // Threads per block

// Shared memory budget
#if defined(USE_ROCM)
constexpr size_t kSmem = 48 * 1024;  // ROCm default: 48KB
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each radix pass needs at most ~TopK candidates in the threshold bin,
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t);  // 32KB (bytes)
#endif

// Plain-old-data launch parameters passed by value to the kernel.
struct FastTopKParams {
  const float* __restrict__ input;         // [batch, seq_len] Logits
  const int32_t* __restrict__ row_starts;  // [batch] Offset into each row
                                           // (optional)
  int32_t* __restrict__ indices;           // [batch, TopK] Output top-k indices
  int32_t* __restrict__ lengths;           // [batch] Sequence lengths per row
  int64_t input_stride;                    // Stride between rows
};

// Map an FP32 bit pattern to a uint32 key whose unsigned ordering matches
// the float's numeric ordering (flip all bits for negatives, set the sign
// bit for non-negatives) — standard radix-select key transform.
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
  uint32_t bits = __float_as_uint(x);
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

// Coarse 8-bit key: round to FP16, apply the same order-preserving bit
// transform, and keep the high byte. Used only for the first (coarse)
// histogram pass; ties are refined later with the full FP32 key.
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
  __half h = __float2half_rn(x);
  uint16_t bits = __half_as_ushort(h);
  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
                                 : static_cast<uint16_t>(bits | 0x8000);
  return static_cast<uint8_t>(key >> 8);
}

// Trivial path for seq_len <= TopK: every position is in the top-k, so emit
// the identity permutation and pad the tail with -1 sentinels.
// Expects a one-block launch slice with kThreadsPerBlock threads.
__device__ void naive_topk_cuda(const float* __restrict__ logits,
                                int32_t* __restrict__ output_indices,
                                int32_t seq_len) {
  const int thread_id = threadIdx.x;
  for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
    output_indices[i] = (i < seq_len) ? i : -1;
  }
}

// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
//
// Multi-pass radix top-k select over one row of logits.
// Requires: blockDim.x == kThreadsPerBlock, and kSmem bytes of dynamic
// shared memory (two ping-pong candidate buffers of MAX_BUFFERED_ITEMS ints).
// Output indices are NOT sorted by value except that the final-pass tail is
// written in reverse slot order.
__device__ void fast_topk_cuda_tl(
    const float* __restrict__ logits,  // Input logits [seq_len]
    int* __restrict__ output_indices,  // Output top-k indices [TopK]
    int logits_offset,                 // Starting offset in logits array
    int seq_len)                       // Number of valid logits to process
{
  constexpr int RADIX = 256;
  constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));

  alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
  alignas(128) __shared__ int shared_output_count;
  alignas(128) __shared__ int shared_threshold_bin;
  alignas(128) __shared__ int shared_buffered_count[2];
  extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];

  const int thread_id = threadIdx.x;
  int remaining_k = TopK;

  // Pass 0: Build coarse 8-bit histogram using FP16 high bits
  if (thread_id < RADIX + 1) {
    shared_histogram[0][thread_id] = 0;
  }
  __syncthreads();
  for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
    const auto bin = convert_to_uint8(logits[idx + logits_offset]);
    ::atomicAdd(&shared_histogram[0][bin], 1);
  }
  __syncthreads();

  // Helper: Compute cumulative sum (suffix sum) over histogram using
  // ping-pong buffers. After 8 doubling steps the result lands back in
  // buffer [0]; the __syncthreads() is outside the guard so all threads
  // reach it.
  auto compute_cumulative_sum = [&]() {
    static_assert(1 << 8 == RADIX,
                  "Radix must be 256 for 8 unrolled iterations");
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      if (C10_LIKELY(thread_id < RADIX)) {
        const int stride = 1 << i;
        const int src_buffer = i & 1;
        const int dst_buffer = src_buffer ^ 1;
        int value = shared_histogram[src_buffer][thread_id];
        if (thread_id < RADIX - stride) {
          value += shared_histogram[src_buffer][thread_id + stride];
        }
        shared_histogram[dst_buffer][thread_id] = value;
      }
      __syncthreads();
    }
  };
  compute_cumulative_sum();

  // Find threshold bin where cumsum crosses remaining_k. Exactly one thread
  // matches; it also resets the shared counters for the next phase.
  if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
      shared_histogram[0][thread_id + 1] <= remaining_k) {
    shared_threshold_bin = thread_id;
    shared_buffered_count[0] = 0;
    shared_output_count = 0;
  }
  __syncthreads();
  const int threshold_bin = shared_threshold_bin;
  remaining_k -= shared_histogram[0][threshold_bin + 1];

  // Early exit if threshold bin perfectly matches remaining_k
  if (remaining_k == 0) {
    for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
      const int bin = convert_to_uint8(logits[idx + logits_offset]);
      if (bin > threshold_bin) {
        const int output_pos = ::atomicAdd(&shared_output_count, 1);
        output_indices[output_pos] = idx;
      }
    }
    __syncthreads();
    return;
  }

  // Prepare for refinement passes: Process threshold bin
  __syncthreads();
  if (thread_id < RADIX + 1) {
    shared_histogram[0][thread_id] = 0;
  }
  __syncthreads();

  // Scan all elements and:
  // 1. Write indices > threshold_bin to output
  // 2. Buffer indices == threshold_bin for refinement
  // 3. Build histogram for next refinement pass (fused optimization)
  for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
    const float logit_value = logits[idx + logits_offset];
    const int bin = convert_to_uint8(logit_value);
    if (bin > threshold_bin) {
      // in top-k, write to output
      const int output_pos = ::atomicAdd(&shared_output_count, 1);
      output_indices[output_pos] = idx;
    } else if (bin == threshold_bin) {
      // Candidate for top-k, needs refinement
      const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
      if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
        buffered_indices[0][buffer_pos] = idx;
        // Fused: Build histogram for next pass
        const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
        const int next_bin = (fp32_bits >> 24) & 0xFF;
        ::atomicAdd(&shared_histogram[0][next_bin], 1);
      }
    }
  }
  __syncthreads();

  // ==========================================================================
  // Passes 1-4: Refine using 8-bit passes over FP32 bits
  // ==========================================================================
  // FP32 bits [31:0] split into 4 bytes processed MSB-first:
  // Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8],
  // Pass 4: bits [7:0]
#pragma unroll 4
  for (int pass = 0; pass < 4; ++pass) {
    __shared__ int shared_final_k;  // For final pass: remaining slots to fill
    const int src_buffer = pass % 2;
    const int dst_buffer = src_buffer ^ 1;

    // Clamp buffered count to prevent overflow
    const int raw_buffered = shared_buffered_count[src_buffer];
    const int num_buffered = (raw_buffered < MAX_BUFFERED_ITEMS)
                                 ? raw_buffered
                                 : MAX_BUFFERED_ITEMS;

    compute_cumulative_sum();

    // Find threshold bin for this pass
    if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
        shared_histogram[0][thread_id + 1] <= remaining_k) {
      shared_threshold_bin = thread_id;
      shared_buffered_count[dst_buffer] = 0;
      shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
    }
    __syncthreads();
    const int threshold_bin = shared_threshold_bin;
    remaining_k -= shared_histogram[0][threshold_bin + 1];

    // Bit offset for this pass: 24, 16, 8, 0
    const int bit_offset = 24 - pass * 8;

    // Early exit if threshold bin perfectly matches
    if (remaining_k == 0) {
      for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
        const int idx = buffered_indices[src_buffer][i];
        const uint32_t fp32_bits =
            convert_to_uint32_v2(logits[idx + logits_offset]);
        const int bin = (fp32_bits >> bit_offset) & 0xFF;
        if (bin > threshold_bin) {
          const int output_pos = ::atomicAdd(&shared_output_count, 1);
          output_indices[output_pos] = idx;
        }
      }
      __syncthreads();
      break;
    }

    // Continue refinement
    __syncthreads();
    if (thread_id < RADIX + 1) {
      shared_histogram[0][thread_id] = 0;
    }
    __syncthreads();

    for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
      const int idx = buffered_indices[src_buffer][i];
      const float logit_value = logits[idx + logits_offset];
      const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
      const int bin = (fp32_bits >> bit_offset) & 0xFF;
      if (bin > threshold_bin) {
        // Definitely in top-k
        const int output_pos = ::atomicAdd(&shared_output_count, 1);
        output_indices[output_pos] = idx;
      } else if (bin == threshold_bin) {
        if (pass == 3) {
          // Final pass (bits [7:0]): No more refinement possible
          // Fill remaining slots in reverse order to maintain descending
          // order
          const int slot = ::atomicAdd(&shared_final_k, -1);
          if (slot > 0) {
            output_indices[TopK - slot] = idx;
          }
        } else {
          // Buffer for next pass and build next histogram
          const int buffer_pos =
              ::atomicAdd(&shared_buffered_count[dst_buffer], 1);
          if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
            buffered_indices[dst_buffer][buffer_pos] = idx;
            // Fused: Build histogram for next pass
            const int next_bit_offset = bit_offset - 8;
            const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
            ::atomicAdd(&shared_histogram[0][next_bin], 1);
          }
        }
      }
    }
    __syncthreads();
  }
}

// One block per batch row; blockDim.x == kThreadsPerBlock; launched with
// kSmem bytes of dynamic shared memory (consumed by fast_topk_cuda_tl).
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
    const FastTopKParams params) {
  const auto& [input, row_starts, indices, lengths, input_stride] = params;
  const uint64_t batch_idx = blockIdx.x;
  const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
  const int seq_len = lengths[batch_idx];
  int* output_indices = indices + batch_idx * TopK;
  const float* logits = input + batch_idx * input_stride;
  if (seq_len <= TopK) {
    // Shortcut: All elements are in top-k
    return naive_topk_cuda(logits, output_indices, seq_len);
  } else {
    return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
  }
}

// Validate tensor shapes/strides and assemble the POD parameter struct.
// row_starts/indices are optional; absent pointers are passed as nullptr.
FastTopKParams get_params(
    const at::Tensor& score, const at::Tensor& lengths,
    std::optional<at::Tensor> row_starts_opt = std::nullopt,
    std::optional<at::Tensor> indices_opt = std::nullopt) {
  const int64_t batch_size = score.size(0);
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
              "score must be 2D with contiguous rows");
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
                  lengths.size(0) == batch_size,
              "lengths must be 1D contiguous with size matching batch");

  const int32_t* row_starts_ptr = nullptr;
  if (row_starts_opt.has_value()) {
    const auto& row_starts = *row_starts_opt;
    TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
                "row_starts must be 1D with size matching batch");
    row_starts_ptr = row_starts.data_ptr<int32_t>();
  }

  int32_t* indices_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = *indices_opt;
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
                    indices.size(0) == batch_size && indices.size(1) == TopK,
                "indices must be 2D contiguous [batch, TopK]");
    indices_ptr = indices.data_ptr<int32_t>();
  }

  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .row_starts = row_starts_ptr,
      .indices = indices_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
  };
}

// Raise the kernel's dynamic shared-memory limit exactly once per process
// (required for >48KB opt-in on CUDA; the static local makes the call
// thread-safe and idempotent).
template <auto kernel_func, size_t smem_bytes>
void setup_kernel_smem_once() {
  static const cudaError_t result = []() -> cudaError_t {
#ifdef USE_ROCM
    // HIP's cudaFuncSetAttribute shim takes the function as const void*.
    auto func_ptr = reinterpret_cast<const void*>(kernel_func);
#else
    auto func_ptr = kernel_func;
#endif
    return cudaFuncSetAttribute(
        func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  }();
  TORCH_CHECK(result == cudaSuccess,
              "Failed to set kernel shared memory limit: ",
              cudaGetErrorString(result));
}

}  // namespace vllm

// Host entry point: per-row top-k over [batch, seq_len] logits.
// Writes up to TopK indices per row into `indices` (rows shorter than TopK
// are padded with -1 by the naive path). All tensors must be CUDA tensors;
// further shape checks are performed in get_params.
void large_context_topk(const torch::Tensor& logits, torch::Tensor& indices,
                        const torch::Tensor& seq_lens,
                        std::optional<torch::Tensor> row_starts = std::nullopt) {
  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
  TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
  TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
  if (row_starts.has_value()) {
    TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
  }

  const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
  const int64_t batch_size = logits.size(0);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid(static_cast<uint32_t>(batch_size));
  const dim3 block(vllm::kThreadsPerBlock);

  vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
  vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);

  const cudaError_t result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "large_context_topk kernel failed: ", cudaGetErrorString(result));
}