[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with dynamic per-row dispatch - DeepSeek-V3.2 DSA decode (#37421)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
75e01a39a1
commit
b55d830ec7
@@ -18,10 +18,9 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
- tests/kernels/core
|
- tests/kernels/core
|
||||||
- tests/kernels/test_top_k_per_row.py
|
|
||||||
- tests/kernels/test_concat_mla_q.py
|
- tests/kernels/test_concat_mla_q.py
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
|
- pytest -v -s kernels/core kernels/test_concat_mla_q.py
|
||||||
|
|
||||||
- label: Kernels Attention Test %N
|
- label: Kernels Attention Test %N
|
||||||
timeout_in_minutes: 35
|
timeout_in_minutes: 35
|
||||||
@@ -107,6 +106,7 @@ steps:
|
|||||||
- vllm/v1/attention/backends/mla/flashinfer_mla.py
|
- vllm/v1/attention/backends/mla/flashinfer_mla.py
|
||||||
- vllm/v1/attention/selector.py
|
- vllm/v1/attention/selector.py
|
||||||
- vllm/platforms/cuda.py
|
- vllm/platforms/cuda.py
|
||||||
|
- tests/kernels/test_top_k_per_row.py
|
||||||
commands:
|
commands:
|
||||||
- nvidia-smi
|
- nvidia-smi
|
||||||
- python3 examples/basic/offline_inference/chat.py
|
- python3 examples/basic/offline_inference/chat.py
|
||||||
@@ -117,6 +117,7 @@ steps:
|
|||||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
||||||
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
||||||
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
|
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
|
||||||
|
- pytest -v -s tests/kernels/test_top_k_per_row.py
|
||||||
# Quantization
|
# Quantization
|
||||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||||
|
|||||||
@@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
|||||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||||
int64_t topK);
|
int64_t topK);
|
||||||
|
|
||||||
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
|
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||||
const torch::Tensor& lengths,
|
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||||
std::optional<torch::Tensor> row_starts_opt);
|
int64_t max_seq_len);
|
||||||
|
|
||||||
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& weight, torch::Tensor& scale,
|
torch::Tensor& weight, torch::Tensor& scale,
|
||||||
|
|||||||
1321
csrc/persistent_topk.cuh
Normal file
1321
csrc/persistent_topk.cuh
Normal file
File diff suppressed because it is too large
Load Diff
483
csrc/topk.cu
483
csrc/topk.cu
@@ -1,373 +1,154 @@
|
|||||||
// Portions of this file are adapted from SGLang PR:
|
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
|
||||||
// https://github.com/sgl-project/sglang/pull/11194
|
// See persistent_topk.cuh for kernel implementation.
|
||||||
// and
|
|
||||||
// https://github.com/sgl-project/sglang/pull/17747
|
|
||||||
|
|
||||||
#include "cuda_compat.h"
|
#include <torch/all.h>
|
||||||
#include "dispatch_utils.h"
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
#include <torch/cuda.h>
|
#include <algorithm>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include <cub/cub.cuh>
|
#include "persistent_topk.cuh"
|
||||||
#else
|
|
||||||
#include <hipcub/hipcub.hpp>
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace vllm {
|
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||||
|
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||||
|
int64_t max_seq_len) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||||
|
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||||
|
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||||
|
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
|
||||||
|
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
||||||
|
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
||||||
|
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||||
|
TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
|
||||||
|
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||||
|
|
||||||
constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k
|
const int64_t num_rows = logits.size(0);
|
||||||
constexpr int kThreadsPerBlock = 1024; // Threads per block
|
const int64_t stride = logits.size(1);
|
||||||
|
|
||||||
// Shared memory budget
|
TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
|
||||||
#if defined(USE_ROCM)
|
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||||
constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB
|
"output size mismatch");
|
||||||
#else
|
namespace P = vllm::persistent;
|
||||||
// Reduced from 128KB to 32KB to improve occupancy.
|
|
||||||
// Each radix pass needs at most ~TopK candidates in the threshold bin,
|
|
||||||
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
|
|
||||||
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
struct FastTopKParams {
|
TORCH_CHECK(k == P::TopK, "k must be 2048");
|
||||||
const float* __restrict__ input; // [batch, seq_len] Logits
|
TORCH_CHECK(k <= stride, "k out of range");
|
||||||
const int32_t* __restrict__ row_starts; // [batch] Offset into each row
|
|
||||||
// (optional)
|
|
||||||
int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices
|
|
||||||
int32_t* __restrict__ lengths; // [batch] Sequence lengths per row
|
|
||||||
int64_t input_stride; // Stride between rows
|
|
||||||
};
|
|
||||||
|
|
||||||
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
uint32_t bits = __float_as_uint(x);
|
|
||||||
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
|
static int num_sms = 0;
|
||||||
__half h = __float2half_rn(x);
|
static int max_smem_per_block = 0;
|
||||||
uint16_t bits = __half_as_ushort(h);
|
if (num_sms == 0) {
|
||||||
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
|
int device;
|
||||||
: static_cast<uint16_t>(bits | 0x8000);
|
cudaGetDevice(&device);
|
||||||
return static_cast<uint8_t>(key >> 8);
|
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
|
||||||
}
|
cudaDeviceGetAttribute(&max_smem_per_block,
|
||||||
|
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||||
__device__ void naive_topk_cuda(const float* __restrict__ logits,
|
|
||||||
int32_t* __restrict__ output_indices,
|
|
||||||
int32_t seq_len) {
|
|
||||||
const int thread_id = threadIdx.x;
|
|
||||||
for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
|
|
||||||
output_indices[i] = (i < seq_len) ? i : -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adapted from:
|
|
||||||
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
|
|
||||||
// by: DarkSharpness
|
|
||||||
// which at the same time is an optimized topk kernel copied from tilelang
|
|
||||||
// kernel
|
|
||||||
__device__ void fast_topk_cuda_tl(
|
|
||||||
const float* __restrict__ logits, // Input logits [seq_len]
|
|
||||||
int* __restrict__ output_indices, // Output top-k indices [TopK]
|
|
||||||
int logits_offset, // Starting offset in logits array
|
|
||||||
int seq_len) // Number of valid logits to process
|
|
||||||
{
|
|
||||||
constexpr int RADIX = 256;
|
|
||||||
constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
|
|
||||||
|
|
||||||
alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
|
|
||||||
alignas(128) __shared__ int shared_output_count;
|
|
||||||
alignas(128) __shared__ int shared_threshold_bin;
|
|
||||||
alignas(128) __shared__ int shared_buffered_count[2];
|
|
||||||
|
|
||||||
extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
|
|
||||||
|
|
||||||
const int thread_id = threadIdx.x;
|
|
||||||
int remaining_k = TopK;
|
|
||||||
|
|
||||||
// Pass 0: Build coarse 8-bit histogram using FP16 high bits
|
|
||||||
if (thread_id < RADIX + 1) {
|
|
||||||
shared_histogram[0][thread_id] = 0;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
|
|
||||||
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
|
|
||||||
::atomicAdd(&shared_histogram[0][bin], 1);
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong
|
|
||||||
// buffers
|
|
||||||
auto compute_cumulative_sum = [&]() {
|
|
||||||
static_assert(1 << 8 == RADIX,
|
|
||||||
"Radix must be 256 for 8 unrolled iterations");
|
|
||||||
#pragma unroll 8
|
|
||||||
for (int i = 0; i < 8; ++i) {
|
|
||||||
if (C10_LIKELY(thread_id < RADIX)) {
|
|
||||||
const int stride = 1 << i;
|
|
||||||
const int src_buffer = i & 1;
|
|
||||||
const int dst_buffer = src_buffer ^ 1;
|
|
||||||
|
|
||||||
int value = shared_histogram[src_buffer][thread_id];
|
|
||||||
if (thread_id < RADIX - stride) {
|
|
||||||
value += shared_histogram[src_buffer][thread_id + stride];
|
|
||||||
}
|
|
||||||
shared_histogram[dst_buffer][thread_id] = value;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
compute_cumulative_sum();
|
|
||||||
|
|
||||||
// Find threshold bin where cumsum crosses remaining_k
|
|
||||||
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
|
|
||||||
shared_histogram[0][thread_id + 1] <= remaining_k) {
|
|
||||||
shared_threshold_bin = thread_id;
|
|
||||||
shared_buffered_count[0] = 0;
|
|
||||||
shared_output_count = 0;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
const int threshold_bin = shared_threshold_bin;
|
|
||||||
remaining_k -= shared_histogram[0][threshold_bin + 1];
|
|
||||||
|
|
||||||
// Early exit if threshold bin perfectly matches remaining_k
|
|
||||||
if (remaining_k == 0) {
|
|
||||||
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
|
|
||||||
const int bin = convert_to_uint8(logits[idx + logits_offset]);
|
|
||||||
if (bin > threshold_bin) {
|
|
||||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
|
||||||
output_indices[output_pos] = idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare for refinement passes: Process threshold bin
|
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
|
||||||
__syncthreads();
|
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
|
||||||
if (thread_id < RADIX + 1) {
|
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
||||||
shared_histogram[0][thread_id] = 0;
|
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||||
}
|
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
|
||||||
__syncthreads();
|
TORCH_CHECK(status == cudaSuccess,
|
||||||
|
"FilteredTopK failed: ", cudaGetErrorString(status));
|
||||||
// Scan all elements and:
|
|
||||||
// 1. Write indices > threshold_bin to output
|
|
||||||
// 2. Buffer indices == threshold_bin for refinement
|
|
||||||
// 3. Build histogram for next refinement pass (fused optimization)
|
|
||||||
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
|
|
||||||
const float logit_value = logits[idx + logits_offset];
|
|
||||||
const int bin = convert_to_uint8(logit_value);
|
|
||||||
|
|
||||||
if (bin > threshold_bin) {
|
|
||||||
// in top-k, write to output
|
|
||||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
|
||||||
output_indices[output_pos] = idx;
|
|
||||||
} else if (bin == threshold_bin) {
|
|
||||||
// Candidate for top-k, needs refinement
|
|
||||||
const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
|
|
||||||
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
|
|
||||||
buffered_indices[0][buffer_pos] = idx;
|
|
||||||
// Fused: Build histogram for next pass
|
|
||||||
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
|
|
||||||
const int next_bin = (fp32_bits >> 24) & 0xFF;
|
|
||||||
::atomicAdd(&shared_histogram[0][next_bin], 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// Passes 1-4: Refine using 8-bit passes over FP32 bits
|
|
||||||
// ============================================================================
|
|
||||||
// FP32 bits [31:0] split into 4 bytes processed MSB-first:
|
|
||||||
// Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4:
|
|
||||||
// bits [7:0]
|
|
||||||
#pragma unroll 4
|
|
||||||
for (int pass = 0; pass < 4; ++pass) {
|
|
||||||
__shared__ int shared_final_k; // For final pass: remaining slots to fill
|
|
||||||
const int src_buffer = pass % 2;
|
|
||||||
const int dst_buffer = src_buffer ^ 1;
|
|
||||||
|
|
||||||
// Clamp buffered count to prevent overflow
|
|
||||||
const int raw_buffered = shared_buffered_count[src_buffer];
|
|
||||||
const int num_buffered =
|
|
||||||
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
|
|
||||||
|
|
||||||
compute_cumulative_sum();
|
|
||||||
|
|
||||||
// Find threshold bin for this pass
|
|
||||||
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
|
|
||||||
shared_histogram[0][thread_id + 1] <= remaining_k) {
|
|
||||||
shared_threshold_bin = thread_id;
|
|
||||||
shared_buffered_count[dst_buffer] = 0;
|
|
||||||
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
const int threshold_bin = shared_threshold_bin;
|
|
||||||
remaining_k -= shared_histogram[0][threshold_bin + 1];
|
|
||||||
|
|
||||||
// Bit offset for this pass: 24, 16, 8, 0
|
|
||||||
const int bit_offset = 24 - pass * 8;
|
|
||||||
|
|
||||||
// Early exit if threshold bin perfectly matches
|
|
||||||
if (remaining_k == 0) {
|
|
||||||
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
|
|
||||||
const int idx = buffered_indices[src_buffer][i];
|
|
||||||
const uint32_t fp32_bits =
|
|
||||||
convert_to_uint32_v2(logits[idx + logits_offset]);
|
|
||||||
const int bin = (fp32_bits >> bit_offset) & 0xFF;
|
|
||||||
if (bin > threshold_bin) {
|
|
||||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
|
||||||
output_indices[output_pos] = idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Continue refinement
|
|
||||||
__syncthreads();
|
|
||||||
if (thread_id < RADIX + 1) {
|
|
||||||
shared_histogram[0][thread_id] = 0;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
|
|
||||||
const int idx = buffered_indices[src_buffer][i];
|
|
||||||
const float logit_value = logits[idx + logits_offset];
|
|
||||||
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
|
|
||||||
const int bin = (fp32_bits >> bit_offset) & 0xFF;
|
|
||||||
|
|
||||||
if (bin > threshold_bin) {
|
|
||||||
// Definitely in top-k
|
|
||||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
|
||||||
output_indices[output_pos] = idx;
|
|
||||||
} else if (bin == threshold_bin) {
|
|
||||||
if (pass == 3) {
|
|
||||||
// Final pass (bits [7:0]): No more refinement possible
|
|
||||||
// Fill remaining slots in reverse order to maintain descending order
|
|
||||||
const int slot = ::atomicAdd(&shared_final_k, -1);
|
|
||||||
if (slot > 0) {
|
|
||||||
output_indices[TopK - slot] = idx;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Buffer for next pass and build next histogram
|
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
||||||
const int buffer_pos =
|
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
|
||||||
::atomicAdd(&shared_buffered_count[dst_buffer], 1);
|
|
||||||
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
|
|
||||||
buffered_indices[dst_buffer][buffer_pos] = idx;
|
|
||||||
// Fused: Build histogram for next pass
|
|
||||||
const int next_bit_offset = bit_offset - 8;
|
|
||||||
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
|
|
||||||
::atomicAdd(&shared_histogram[0][next_bin], 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
|
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
|
||||||
const FastTopKParams params) {
|
// large path. Empirically tuned.
|
||||||
const auto& [input, row_starts, indices, lengths, input_stride] = params;
|
int effective_max_smem;
|
||||||
const uint64_t batch_idx = blockIdx.x;
|
if (num_rows <= 4) {
|
||||||
const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
|
effective_max_smem =
|
||||||
const int seq_len = lengths[batch_idx];
|
std::min(max_smem_per_block, static_cast<int>(P::kSmemMedium));
|
||||||
int* output_indices = indices + batch_idx * TopK;
|
} else if (num_rows <= 8) {
|
||||||
const float* logits = input + batch_idx * input_stride;
|
constexpr int kSmemCapMedium = 48 * 1024;
|
||||||
|
effective_max_smem = std::min(max_smem_per_block, kSmemCapMedium);
|
||||||
if (seq_len <= TopK) {
|
|
||||||
// Shortcut: All elements are in top-k
|
|
||||||
return naive_topk_cuda(logits, output_indices, seq_len);
|
|
||||||
} else {
|
} else {
|
||||||
return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
|
effective_max_smem = max_smem_per_block;
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FastTopKParams get_params(
|
|
||||||
const at::Tensor& score, const at::Tensor& lengths,
|
|
||||||
std::optional<at::Tensor> row_starts_opt = std::nullopt,
|
|
||||||
std::optional<at::Tensor> indices_opt = std::nullopt) {
|
|
||||||
const int64_t batch_size = score.size(0);
|
|
||||||
|
|
||||||
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
|
|
||||||
"score must be 2D with contiguous rows");
|
|
||||||
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
|
|
||||||
lengths.size(0) == batch_size,
|
|
||||||
"lengths must be 1D contiguous with size matching batch");
|
|
||||||
|
|
||||||
const int32_t* row_starts_ptr = nullptr;
|
|
||||||
if (row_starts_opt.has_value()) {
|
|
||||||
const auto& row_starts = *row_starts_opt;
|
|
||||||
TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
|
|
||||||
"row_starts must be 1D with size matching batch");
|
|
||||||
row_starts_ptr = row_starts.data_ptr<int32_t>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t* indices_ptr = nullptr;
|
size_t available_for_ordered =
|
||||||
if (indices_opt.has_value()) {
|
static_cast<size_t>(effective_max_smem) - P::kFixedSmemLarge;
|
||||||
const auto& indices = *indices_opt;
|
uint32_t max_chunk_elements =
|
||||||
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
|
static_cast<uint32_t>(available_for_ordered / sizeof(uint32_t));
|
||||||
indices.size(0) == batch_size && indices.size(1) == TopK,
|
|
||||||
"indices must be 2D contiguous [batch, TopK]");
|
uint32_t vec_size = 1;
|
||||||
indices_ptr = indices.data_ptr<int32_t>();
|
if (stride % 4 == 0)
|
||||||
|
vec_size = 4;
|
||||||
|
else if (stride % 2 == 0)
|
||||||
|
vec_size = 2;
|
||||||
|
|
||||||
|
max_chunk_elements = (max_chunk_elements / vec_size) * vec_size;
|
||||||
|
uint32_t min_chunk = vec_size * P::kThreadsPerBlock;
|
||||||
|
if (max_chunk_elements < min_chunk) max_chunk_elements = min_chunk;
|
||||||
|
|
||||||
|
uint32_t ctas_per_group =
|
||||||
|
(static_cast<uint32_t>(stride) + max_chunk_elements - 1) /
|
||||||
|
max_chunk_elements;
|
||||||
|
uint32_t chunk_size =
|
||||||
|
(static_cast<uint32_t>(stride) + ctas_per_group - 1) / ctas_per_group;
|
||||||
|
chunk_size = ((chunk_size + vec_size - 1) / vec_size) * vec_size;
|
||||||
|
if (chunk_size > max_chunk_elements) chunk_size = max_chunk_elements;
|
||||||
|
|
||||||
|
size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
|
||||||
|
if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;
|
||||||
|
|
||||||
|
int occupancy = 1;
|
||||||
|
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
|
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
|
||||||
|
smem_size);
|
||||||
|
if (occupancy < 1) occupancy = 1;
|
||||||
|
|
||||||
|
uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
|
||||||
|
uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
|
||||||
|
static_cast<uint32_t>(num_rows));
|
||||||
|
if (num_groups == 0) num_groups = 1;
|
||||||
|
uint32_t total_ctas = num_groups * ctas_per_group;
|
||||||
|
|
||||||
|
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
|
||||||
|
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
|
||||||
|
"workspace too small, need ", state_bytes, " bytes");
|
||||||
|
|
||||||
|
P::PersistentTopKParams params;
|
||||||
|
params.input = logits.data_ptr<float>();
|
||||||
|
params.output = output.data_ptr<int32_t>();
|
||||||
|
params.lengths = lengths.data_ptr<int32_t>();
|
||||||
|
params.num_rows = static_cast<uint32_t>(num_rows);
|
||||||
|
params.stride = static_cast<uint32_t>(stride);
|
||||||
|
params.chunk_size = chunk_size;
|
||||||
|
params.row_states =
|
||||||
|
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
|
||||||
|
params.ctas_per_group = ctas_per_group;
|
||||||
|
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
|
||||||
|
|
||||||
|
#define LAUNCH_PERSISTENT(VS) \
|
||||||
|
do { \
|
||||||
|
auto kernel = &P::persistent_topk_kernel<VS>; \
|
||||||
|
cudaError_t err = cudaFuncSetAttribute( \
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
|
||||||
|
TORCH_CHECK(err == cudaSuccess, \
|
||||||
|
"Failed to set smem: ", cudaGetErrorString(err)); \
|
||||||
|
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
if (vec_size == 4) {
|
||||||
|
LAUNCH_PERSISTENT(4);
|
||||||
|
} else if (vec_size == 2) {
|
||||||
|
LAUNCH_PERSISTENT(2);
|
||||||
|
} else {
|
||||||
|
LAUNCH_PERSISTENT(1);
|
||||||
|
}
|
||||||
|
#undef LAUNCH_PERSISTENT
|
||||||
}
|
}
|
||||||
|
|
||||||
return FastTopKParams{
|
cudaError_t err = cudaGetLastError();
|
||||||
.input = score.data_ptr<float>(),
|
TORCH_CHECK(err == cudaSuccess,
|
||||||
.row_starts = row_starts_ptr,
|
"persistent_topk failed: ", cudaGetErrorString(err));
|
||||||
.indices = indices_ptr,
|
|
||||||
.lengths = lengths.data_ptr<int32_t>(),
|
|
||||||
.input_stride = score.stride(0),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <auto* kernel_func, size_t smem_bytes>
|
|
||||||
void setup_kernel_smem_once() {
|
|
||||||
static const cudaError_t result = []() -> cudaError_t {
|
|
||||||
#ifdef USE_ROCM
|
|
||||||
auto func_ptr = reinterpret_cast<const void*>(kernel_func);
|
|
||||||
#else
|
#else
|
||||||
auto func_ptr = kernel_func;
|
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
||||||
#endif
|
#endif
|
||||||
return cudaFuncSetAttribute(
|
|
||||||
func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
|
||||||
}();
|
|
||||||
|
|
||||||
TORCH_CHECK(
|
|
||||||
result == cudaSuccess,
|
|
||||||
"Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
|
|
||||||
void large_context_topk(
|
|
||||||
const torch::Tensor& logits, torch::Tensor& indices,
|
|
||||||
const torch::Tensor& seq_lens,
|
|
||||||
std::optional<torch::Tensor> row_starts = std::nullopt) {
|
|
||||||
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
|
|
||||||
TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
|
|
||||||
TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
|
|
||||||
if (row_starts.has_value()) {
|
|
||||||
TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
|
|
||||||
const int64_t batch_size = logits.size(0);
|
|
||||||
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
const dim3 grid(static_cast<uint32_t>(batch_size));
|
|
||||||
const dim3 block(vllm::kThreadsPerBlock);
|
|
||||||
|
|
||||||
vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
|
|
||||||
vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
|
|
||||||
|
|
||||||
const cudaError_t result = cudaGetLastError();
|
|
||||||
TORCH_CHECK(result == cudaSuccess,
|
|
||||||
"large_context_topk kernel failed: ", cudaGetErrorString(result));
|
|
||||||
}
|
}
|
||||||
@@ -197,10 +197,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
||||||
|
|
||||||
ops.def(
|
ops.def(
|
||||||
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
|
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
|
||||||
"Tensor? "
|
"Tensor workspace, int k, int max_seq_len) -> ()");
|
||||||
"row_starts_opt) -> ()");
|
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
|
||||||
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
|
|
||||||
|
|
||||||
// Layernorm-quant
|
// Layernorm-quant
|
||||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||||
|
|||||||
@@ -122,6 +122,39 @@ def compare_top_k_results(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def validate_topk_against_reference(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
cuda_indices: torch.Tensor,
|
||||||
|
row_starts: torch.Tensor,
|
||||||
|
row_ends: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
kernel_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate CUDA top-k results against PyTorch reference implementation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Input logits tensor
|
||||||
|
cuda_indices: CUDA kernel output indices
|
||||||
|
row_starts: Row start positions
|
||||||
|
row_ends: Row end positions
|
||||||
|
top_k: Number of top elements to select
|
||||||
|
kernel_name: Name of the kernel being tested (for error messages)
|
||||||
|
"""
|
||||||
|
num_rows = cuda_indices.shape[0]
|
||||||
|
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
for i in range(num_rows):
|
||||||
|
row_end = int(row_ends[i])
|
||||||
|
k_i = min(top_k, row_end)
|
||||||
|
idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
|
||||||
|
torch_indices[i, :k_i] = idx
|
||||||
|
|
||||||
|
assert compare_top_k_results(
|
||||||
|
logits, cuda_indices, torch_indices, row_starts, row_ends, top_k
|
||||||
|
), f"{kernel_name} results don't match torch.topk"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_rows", NUM_ROWS)
|
@pytest.mark.parametrize("num_rows", NUM_ROWS)
|
||||||
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
||||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||||
@@ -278,111 +311,540 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seq_len_range,test_id",
|
||||||
|
[
|
||||||
|
pytest.param((4000, 8000), "short_sequences", id="short"),
|
||||||
|
pytest.param((8000, 32000), "medium_sequences", id="medium"),
|
||||||
|
pytest.param((32000, 163840), "long_sequences", id="long"),
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||||
|
@pytest.mark.parametrize("top_k", [2048])
|
||||||
|
@pytest.mark.parametrize("next_n", [1, 4])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
|
def test_deepseek_persistent_topk(
|
||||||
|
seq_len_range: tuple[int, int],
|
||||||
|
test_id: str,
|
||||||
|
clean_logits: bool,
|
||||||
|
top_k: int,
|
||||||
|
next_n: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test persistent_topk with varying sequence lengths and speculative decoding.
|
||||||
|
Supports speculative decoding with next_n > 1.
|
||||||
|
"""
|
||||||
|
set_random_seed(42 if test_id == "short_sequences" else 43)
|
||||||
|
torch.set_default_device("cuda:0")
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
num_rows = batch_size * next_n
|
||||||
|
|
||||||
|
seq_lens = torch.randint(
|
||||||
|
seq_len_range[0],
|
||||||
|
seq_len_range[1],
|
||||||
|
(batch_size,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute row boundaries for speculative decoding
|
||||||
|
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
|
||||||
|
row_indices = torch.arange(num_rows, device="cuda") // next_n
|
||||||
|
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
|
||||||
|
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
|
||||||
|
|
||||||
|
logits = create_random_logits(
|
||||||
|
row_starts, row_ends, torch.float32, 42, clean_logits, "random"
|
||||||
|
)
|
||||||
|
|
||||||
|
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
if next_n == 1:
|
||||||
|
lengths = seq_lens
|
||||||
|
else:
|
||||||
|
offsets = torch.arange(next_n, device=logits.device, dtype=torch.int32)
|
||||||
|
lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
|
||||||
|
|
||||||
|
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
|
||||||
|
max_seq_len = int(seq_lens.max().item())
|
||||||
|
torch.ops._C.persistent_topk(
|
||||||
|
logits, lengths, indices, workspace, top_k, max_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_topk_against_reference(
|
||||||
|
logits, indices, row_starts, row_ends, top_k, f"persistent_topk ({test_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_large_context_topk_test(
    batch_size: int,
    seq_lens: list[int],
    top_k: int,
    data_type: str = "random",
    seed: int = 42,
) -> None:
    """
    Run the persistent_topk kernel on synthetic logits and validate the
    selected indices against torch.topk.

    Args:
        batch_size: Number of rows/sequences
        seq_lens: List of sequence lengths (one per row)
        top_k: Number of top elements to select
        data_type: Type of test data to generate. One of "random",
            "sorted_asc", "sorted_desc", "all_same", "many_ties",
            "small_differences".
        seed: Random seed for reproducibility

    Raises:
        ValueError: If data_type is not one of the supported names.
        AssertionError: If a row of the kernel output is not a valid top-k
            selection (wrong count, duplicate or out-of-range indices, or
            values that do not match the reference up to ties).
    """
    torch.set_default_device("cuda:0")
    set_random_seed(seed)

    # Create test data
    num_rows = batch_size
    max_len = max(seq_lens)
    lengths = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")

    if data_type == "random":
        # Random data is left unmasked past each row's length, so only
        # `lengths` tells the kernel where a row ends.
        logits = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda")
    elif data_type == "sorted_asc":
        # Each row gets its own ascending sequence based on its length
        logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda")
        for i, length in enumerate(seq_lens):
            logits[i, :length] = torch.arange(
                length, dtype=torch.float32, device="cuda"
            )
    elif data_type == "sorted_desc":
        # Each row gets its own descending sequence based on its length
        logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda")
        for i, length in enumerate(seq_lens):
            logits[i, :length] = torch.arange(
                length, 0, -1, dtype=torch.float32, device="cuda"
            )
    elif data_type == "all_same":
        logits = torch.ones(num_rows, max_len, dtype=torch.float32, device="cuda")
    elif data_type == "many_ties":
        # Only 10 unique values, many duplicates
        logits = torch.randint(0, 10, (num_rows, max_len), device="cuda").float() / 10.0
    elif data_type == "small_differences":
        # Very small differences to test float precision
        base = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda")
        noise = (
            torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda") * 1e-6
        )
        logits = base + noise
    else:
        raise ValueError(f"Unknown data_type: {data_type}")

    # Every structured distribution pads the tail of short rows with -inf.
    if data_type != "random":
        for i, length in enumerate(seq_lens):
            if length < max_len:
                logits[i, length:] = float("-inf")

    # Create output tensor
    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")

    workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
    torch.ops._C.persistent_topk(logits, lengths, indices, workspace, top_k, max_len)

    torch.accelerator.synchronize()

    # Compare results row by row against torch.topk
    for i in range(num_rows):
        length = seq_lens[i]
        k_i = min(top_k, length)

        if k_i <= 0:
            continue

        # Filter out -1 padding values from the kernel output
        cuda_row = indices[i, :k_i].cpu()
        cuda_row = cuda_row[cuda_row >= 0]
        cuda_list = cuda_row.tolist()
        cuda_set = set(cuda_list)

        # Structural checks. Without them, tie-heavy data would let a row of
        # duplicated (or out-of-range) indices pass the value-based checks.
        assert len(cuda_list) == k_i, (
            f"Row {i}: expected {k_i} valid indices, got {len(cuda_list)} "
            f"(length={length})"
        )
        assert len(cuda_set) == k_i, (
            f"Row {i}: kernel returned duplicate indices "
            f"({k_i - len(cuda_set)} duplicates, length={length})"
        )
        assert max(cuda_set) < length, (
            f"Row {i}: kernel returned index {max(cuda_set)} outside the "
            f"valid range [0, {length})"
        )

        torch_row = logits[i, :length].topk(k_i, dim=-1)[1].cpu()
        torch_set = set(torch_row.tolist())

        # Compare sets (order may differ for ties)
        if cuda_set == torch_set:
            continue

        # If sets differ, check if it's due to equal values (ties)
        cuda_vals = logits[i, cuda_row.long()].cpu()
        torch_vals = logits[i, torch_row].cpu()

        # Check that min CUDA value >= max of values NOT in top-k
        if k_i < length:
            not_selected = torch.ones(length, dtype=torch.bool, device=logits.device)
            not_selected[cuda_row.long().to(logits.device)] = False
            non_topk_vals = logits[i, :length][not_selected].cpu()
            if len(non_topk_vals) > 0:
                min_cuda_val = cuda_vals.min()
                max_non_topk = non_topk_vals.max()

                # Allow small tolerance for floating point errors
                assert min_cuda_val >= max_non_topk - 1e-4, (
                    f"Row {i}: CUDA top-k contains values smaller than non-top-k. "
                    f"Min CUDA: {min_cuda_val}, Max non-top-k: {max_non_topk}, "
                    f"Length: {length}, k: {k_i}, CUDA indices: {sorted(cuda_set)[:10]}..., "  # noqa: E501
                    f"Expected indices: {sorted(torch_set)[:10]}..."
                )

        # For ties, verify the values are close
        cuda_sorted = cuda_vals.sort(descending=True)[0]
        torch_sorted = torch_vals.sort(descending=True)[0]
        assert torch.allclose(
            cuda_sorted,
            torch_sorted,
            rtol=1e-4,
            atol=1e-4,
        ), f"""Row {i}: Top-k values don't match.
            CUDA: {cuda_sorted[:10]},
            Torch: {torch_sorted[:10]}"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
    "test_config",
    [
        pytest.param(
            dict(seq_lens=row_lens, top_k=2048, data_type=distribution),
            id=case_id,
        )
        for case_id, row_lens, distribution in (
            # ---- Sequence length edge cases ----
            ("seq_len_edge_very_small_to_medium", [1, 10, 100, 2048], "random"),
            ("seq_len_edge_above_k", [2049, 2100, 2500, 3000], "random"),
            ("algo_transition_filtered_radix", [8000, 16384, 20000], "random"),
            # ---- Data distributions ----
            ("data_sorted_ascending", [5000, 10000], "sorted_asc"),
            ("data_sorted_descending", [5000, 10000], "sorted_desc"),
            ("data_all_same", [5000, 10000], "all_same"),
            ("data_many_ties", [5000, 10000], "many_ties"),
            ("data_float_precision", [5000, 10000], "small_differences"),
            # ---- Alignment / vectorization ----
            ("align_vec_boundaries_low", [2055, 2056, 2057, 2063], "random"),
            ("align_4k_boundary", [4095, 4096, 4097, 4102], "random"),
            ("align_8k_boundary", [8191, 8192, 8193, 8198], "random"),
            ("align_16k_boundary", [16383, 16384, 16385, 16390], "random"),
        )
    ],
)
@torch.inference_mode()
def test_persistent_topk_correctness(test_config: dict) -> None:
    """
    Correctness sweep for persistent_topk, one batch per parametrized case:
    - sequence lengths below, at and above top_k, mixed within a batch
    - sorted, constant, tie-heavy and near-equal value distributions
    - row lengths straddling the 2K/4K/8K/16K boundaries
    """
    row_lens = test_config["seq_lens"]
    run_large_context_topk_test(
        batch_size=len(row_lens),
        seq_lens=row_lens,
        top_k=test_config["top_k"],
        data_type=test_config.get("data_type", "random"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
    "test_config",
    [
        pytest.param(
            dict(batch_size=rows, seq_len=row_len, top_k=2048),
            id=case_id,
        )
        for case_id, rows, row_len in (
            # ---- Batch size scalability ----
            ("batch_1", 1, 5000),
            ("batch_4", 4, 5000),
            ("batch_32", 32, 5000),
            ("batch_256", 256, 5000),
            # ---- Single-CTA vs Multi-CTA ----
            ("single_cta_4k", 2, 4096),
            ("single_cta_8k", 2, 8192),
            ("multi_cta_163840_dsv3_max", 2, 163840),
            # ---- Extreme cases ----
            ("extreme_large_batch", 512, 5000),
            ("extreme_dsv3_max_context", 2, 163840),
        )
    ],
)
@torch.inference_mode()
def test_persistent_topk_algorithm_paths(test_config: dict) -> None:
    """
    Run persistent_topk on uniform-length batches chosen to vary:
    - the batch size (1 up to 512 rows)
    - the row length, from 4096 up to 163840 (the DeepSeek V3.2 cap
      used throughout these tests)
    """
    rows = test_config["batch_size"]
    uniform_lens = [test_config["seq_len"]] * rows
    run_large_context_topk_test(
        batch_size=rows,
        seq_lens=uniform_lens,
        top_k=test_config["top_k"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_persistent_topk_stress() -> None:
    """
    Randomized stress test: three seeded trials, each with a random batch
    size and random per-row sequence lengths below 163840
    (DeepSeek V3.2 max context).
    """
    torch.set_default_device("cuda:0")
    k = 2048
    num_trials = 3

    for trial_seed in range(num_trials):
        set_random_seed(trial_seed)

        # Draw order matters for reproducibility: batch size first, then
        # the per-row lengths.
        num_seqs = torch.randint(1, 32, (1,)).item()
        random_lens = torch.randint(100, 163840, (num_seqs,)).tolist()

        run_large_context_topk_test(
            batch_size=num_seqs,
            seq_lens=random_lens,
            top_k=k,
            seed=trial_seed,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
    "test_config",
    [
        pytest.param(
            dict(seq_lens=row_lens, top_k=2048, data_type="random"),
            id=case_id,
        )
        for case_id, row_lens in (
            # Mixed batch: rows spanning all four paths
            # (trivial, decode, medium, large)
            ("mixed_all_paths", [2000, 6000, 30000, 80000]),
            # All decode/medium rows (typical decode scenario)
            ("all_decode_medium", [2048, 4096, 8192, 16000]),
            # All large rows
            ("all_large", [70000, 100000, 163840]),
            # Boundary around LARGE_THRESHOLD (32K)
            ("large_threshold_boundary", [32767, 32768, 32769, 32772]),
            # Single row medium
            ("single_row_medium", [5000]),
            # Single row large
            ("single_row_large", [100000]),
            # Trivial rows mixed with medium and large
            ("trivial_medium_large_mix", [100, 2048, 10000, 80000]),
        )
    ],
)
@torch.inference_mode()
def test_persistent_topk(test_config: dict) -> None:
    """
    Cases aimed at the per-row dispatch of the persistent_topk kernel:
    - batches that mix trivial, medium and large rows
    - row lengths on either side of LARGE_THRESHOLD (32K)
    - single-row batches of medium and large length
    """
    row_lens = test_config["seq_lens"]
    run_large_context_topk_test(
        batch_size=len(row_lens),
        seq_lens=row_lens,
        top_k=test_config["top_k"],
        data_type=test_config.get("data_type", "random"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_persistent_topk_padded_stride() -> None:
|
||||||
|
"""
|
||||||
|
Test persistent_topk with padded logits (large stride, small seq_len)
|
||||||
|
to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits
|
||||||
|
returns [B, max_model_len] with max_model_len=163840.
|
||||||
|
"""
|
||||||
|
set_random_seed(42)
|
||||||
torch.set_default_device("cuda:0")
|
torch.set_default_device("cuda:0")
|
||||||
|
|
||||||
top_k = 2048
|
top_k = 2048
|
||||||
|
batch_size = 4
|
||||||
|
padded_stride = 163840 # DeepSeek-V3.2 max_model_len
|
||||||
|
actual_seq_lens = [3000, 5000, 8000, 12000]
|
||||||
|
|
||||||
# Test case 1: Short sequences (< 8192)
|
# Create padded logits tensor (like fp8_paged_mqa_logits output)
|
||||||
batch_size_short = 4
|
logits = torch.full(
|
||||||
next_n = 1
|
(batch_size, padded_stride),
|
||||||
num_rows_short = batch_size_short * next_n
|
float("-inf"),
|
||||||
|
dtype=torch.float32,
|
||||||
# Create sequences with max length < 8192
|
device="cuda",
|
||||||
seq_lens_short = torch.randint(
|
|
||||||
4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
)
|
||||||
|
for i, sl in enumerate(actual_seq_lens):
|
||||||
|
logits[i, :sl] = torch.randn(sl, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda")
|
lengths = torch.tensor(actual_seq_lens, dtype=torch.int32, device="cuda")
|
||||||
row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n
|
indices = torch.empty((batch_size, top_k), dtype=torch.int32, device="cuda")
|
||||||
next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n
|
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
|
||||||
row_ends_short = (
|
|
||||||
seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1
|
torch.ops._C.persistent_topk(
|
||||||
|
logits, lengths, indices, workspace, top_k, max(actual_seq_lens)
|
||||||
)
|
)
|
||||||
|
torch.accelerator.synchronize()
|
||||||
|
|
||||||
logits_short = create_random_logits(
|
# Validate against torch.topk
|
||||||
row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random"
|
for i in range(batch_size):
|
||||||
|
sl = actual_seq_lens[i]
|
||||||
|
k_i = min(top_k, sl)
|
||||||
|
expected = logits[i, :sl].topk(k_i, dim=-1)[1].cpu()
|
||||||
|
actual = indices[i, :k_i].cpu()
|
||||||
|
|
||||||
|
expected_set = set(expected.tolist())
|
||||||
|
actual_set = set(actual.tolist())
|
||||||
|
|
||||||
|
if expected_set != actual_set:
|
||||||
|
# Allow ties
|
||||||
|
expected_vals = logits[i, expected].cpu().sort(descending=True)[0]
|
||||||
|
actual_vals = logits[i, actual].cpu().sort(descending=True)[0]
|
||||||
|
assert torch.allclose(expected_vals, actual_vals, rtol=1e-4, atol=1e-4), (
|
||||||
|
f"Row {i}: persistent_topk with padded stride doesn't match. "
|
||||||
|
f"seq_len={sl}, stride={padded_stride}"
|
||||||
)
|
)
|
||||||
|
|
||||||
indices_vllm = torch.empty(
|
|
||||||
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use vllm's kernel for short sequences
|
|
||||||
torch.ops._C.top_k_per_row_decode(
|
|
||||||
logits_short,
|
|
||||||
next_n,
|
|
||||||
seq_lens_short,
|
|
||||||
indices_vllm,
|
|
||||||
num_rows_short,
|
|
||||||
logits_short.stride(0),
|
|
||||||
logits_short.stride(1),
|
|
||||||
top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel
|
|
||||||
batch_size_long = 4
|
|
||||||
num_rows_long = batch_size_long * next_n
|
|
||||||
|
|
||||||
# Create sequences with max length >= 8192
|
|
||||||
seq_lens_long = torch.randint(
|
|
||||||
8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda")
|
|
||||||
row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n
|
|
||||||
next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n
|
|
||||||
row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1
|
|
||||||
|
|
||||||
logits_long = create_random_logits(
|
|
||||||
row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random"
|
|
||||||
)
|
|
||||||
|
|
||||||
indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
# Use large_context_topk kernel for long sequences
|
|
||||||
if next_n == 1:
|
|
||||||
lengths = seq_lens_long
|
|
||||||
else:
|
|
||||||
offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32)
|
|
||||||
lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten()
|
|
||||||
|
|
||||||
torch.ops._C.large_context_topk(
|
|
||||||
logits_long,
|
|
||||||
indices,
|
|
||||||
lengths,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_indices_short = torch.empty(
|
|
||||||
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
for i in range(num_rows_short):
|
|
||||||
row_end = int(row_ends_short[i])
|
|
||||||
k_i = min(top_k, row_end)
|
|
||||||
idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1]
|
|
||||||
torch_indices_short[i, :k_i] = idx
|
|
||||||
|
|
||||||
assert compare_top_k_results(
|
|
||||||
logits_short,
|
|
||||||
indices_vllm,
|
|
||||||
torch_indices_short,
|
|
||||||
row_starts_short,
|
|
||||||
row_ends_short,
|
|
||||||
top_k,
|
|
||||||
), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
|
|
||||||
|
|
||||||
torch_indices_long = torch.empty(
|
|
||||||
(num_rows_long, top_k), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
for i in range(num_rows_long):
|
|
||||||
row_end = int(row_ends_long[i])
|
|
||||||
k_i = min(top_k, row_end)
|
|
||||||
idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1]
|
|
||||||
torch_indices_long[i, :k_i] = idx
|
|
||||||
|
|
||||||
assert compare_top_k_results(
|
|
||||||
logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k
|
|
||||||
), "large_context_topk kernel (long sequences) doesn't match torch.topk"
|
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ elif current_platform.is_xpu():
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
def sparse_attn_indexer(
|
def sparse_attn_indexer(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -51,6 +53,7 @@ def sparse_attn_indexer(
|
|||||||
current_workspace_manager().get_simultaneous(
|
current_workspace_manager().get_simultaneous(
|
||||||
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
||||||
((total_seq_lens, 4), torch.uint8),
|
((total_seq_lens, 4), torch.uint8),
|
||||||
|
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dummy allocation to simulate for peak logits tensor memory during inference.
|
# Dummy allocation to simulate for peak logits tensor memory during inference.
|
||||||
@@ -157,15 +160,6 @@ def sparse_attn_indexer(
|
|||||||
topk_tokens,
|
topk_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute lengths from row spans
|
|
||||||
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
|
|
||||||
# torch.ops._C.large_context_topk(
|
|
||||||
# logits,
|
|
||||||
# topk_indices,
|
|
||||||
# lengths,
|
|
||||||
# chunk.cu_seqlen_ks, # row_starts
|
|
||||||
# )
|
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_metadata = attn_metadata.decode
|
decode_metadata = attn_metadata.decode
|
||||||
assert decode_metadata is not None
|
assert decode_metadata is not None
|
||||||
@@ -204,7 +198,6 @@ def sparse_attn_indexer(
|
|||||||
num_rows = logits.shape[0]
|
num_rows = logits.shape[0]
|
||||||
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
||||||
|
|
||||||
if decode_metadata.use_large_context_topk:
|
|
||||||
if next_n == 1:
|
if next_n == 1:
|
||||||
lengths = decode_metadata.seq_lens
|
lengths = decode_metadata.seq_lens
|
||||||
else:
|
else:
|
||||||
@@ -216,11 +209,18 @@ def sparse_attn_indexer(
|
|||||||
+ decode_metadata.offsets
|
+ decode_metadata.offsets
|
||||||
).flatten()
|
).flatten()
|
||||||
|
|
||||||
torch.ops._C.large_context_topk(
|
if current_platform.is_cuda():
|
||||||
|
workspace_manager = current_workspace_manager()
|
||||||
|
(topk_workspace,) = workspace_manager.get_simultaneous(
|
||||||
|
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
|
||||||
|
)
|
||||||
|
torch.ops._C.persistent_topk(
|
||||||
logits,
|
logits,
|
||||||
topk_indices,
|
|
||||||
lengths,
|
lengths,
|
||||||
None,
|
topk_indices,
|
||||||
|
topk_workspace,
|
||||||
|
topk_tokens,
|
||||||
|
attn_metadata.max_seq_len,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if current_platform.is_xpu():
|
if current_platform.is_xpu():
|
||||||
|
|||||||
@@ -67,7 +67,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
|
from vllm.model_executor.layers.sparse_attn_indexer import (
|
||||||
|
SparseAttnIndexer,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
@@ -1203,7 +1205,9 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: DeepseekV2DecoderLayer(
|
lambda prefix: DeepseekV2DecoderLayer(
|
||||||
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
|
vllm_config,
|
||||||
|
prefix,
|
||||||
|
topk_indices_buffer=topk_indices_buffer,
|
||||||
),
|
),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,7 +145,6 @@ class DeepSeekV32IndexerDecodeMetadata:
|
|||||||
decode_lens: torch.Tensor
|
decode_lens: torch.Tensor
|
||||||
requires_padding: bool
|
requires_padding: bool
|
||||||
schedule_metadata: torch.Tensor
|
schedule_metadata: torch.Tensor
|
||||||
use_large_context_topk: bool
|
|
||||||
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
|
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
|
||||||
|
|
||||||
|
|
||||||
@@ -437,7 +436,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
|
|
||||||
if use_native and next_n > 1:
|
if use_native and next_n > 1:
|
||||||
offsets = self.offsets_buffer
|
offsets = self.offsets_buffer
|
||||||
batch_size = num_decodes
|
|
||||||
elif max_decode_len > 1:
|
elif max_decode_len > 1:
|
||||||
# Flatten multi-token decode requests into single-token
|
# Flatten multi-token decode requests into single-token
|
||||||
# batch entries, expanding seq_lens and block tables so
|
# batch entries, expanding seq_lens and block tables so
|
||||||
@@ -496,10 +494,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
self.decode_lens_buffer[:num_decode_tokens] = 1
|
self.decode_lens_buffer[:num_decode_tokens] = 1
|
||||||
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
|
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
|
||||||
offsets = None
|
offsets = None
|
||||||
batch_size = num_decode_tokens
|
|
||||||
else:
|
else:
|
||||||
offsets = None
|
offsets = None
|
||||||
batch_size = num_decodes
|
|
||||||
|
|
||||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
||||||
if current_platform.is_cuda() and has_deep_gemm():
|
if current_platform.is_cuda() and has_deep_gemm():
|
||||||
@@ -509,20 +505,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
self.num_sms,
|
self.num_sms,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decide which top-k kernel to use based on batch size and sequence length
|
|
||||||
# Decision logic based on micro-benchmark results:
|
|
||||||
# - large_context_topk wins for batch <= 128 and seq_len > 8K
|
|
||||||
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
|
|
||||||
_is_large_context = common_attn_metadata.max_seq_len > 8192
|
|
||||||
use_large_context_topk = batch_size <= 128 and _is_large_context
|
|
||||||
|
|
||||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||||
block_table=block_table,
|
block_table=block_table,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
decode_lens=decode_lens,
|
decode_lens=decode_lens,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
schedule_metadata=self.scheduler_metadata_buffer,
|
schedule_metadata=self.scheduler_metadata_buffer,
|
||||||
use_large_context_topk=use_large_context_topk,
|
|
||||||
offsets=offsets,
|
offsets=offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user