[Refactor] Improve indexer decode path metadata preparation (#38865)

This commit is contained in:
Yongye Zhu
2026-04-08 23:49:15 -04:00
committed by GitHub
parent ef5a226819
commit 2e98406048
4 changed files with 162 additions and 102 deletions

View File

@@ -564,8 +564,9 @@ template <int kNumThreadsPerBlock, bool useRadixSort,
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
const float* logits, const int* seqLens, int* outIndices, int stride0,
int stride1, const int topK, int next_n, float* outLogits = nullptr,
const int numBlocksToMerge = 0, const int* indices = nullptr) {
int stride1, const int topK, int next_n, int seqLensIs2D = 0,
float* outLogits = nullptr, const int numBlocksToMerge = 0,
const int* indices = nullptr) {
// The number of bins in the histogram.
static constexpr int kNumBins = 2048;
@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// The range of logits within the row.
int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
int batch_idx = rowIdx / next_n;
int next_n_idx = rowIdx % next_n;
// seqLensIs2D=0: 1D seqLens — all rows in a batch share seqLens[batch_idx];
// the kernel derives the per-row effective length as
// seq_len - next_n + next_n_idx + 1.
// seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
// effective length, used as-is (flat index rowIdx = b*next_n + j maps
// directly to seqLens[b, j] in C-contiguous layout).
// In both cases the resulting rowEnd is clamped to be non-negative.
int seq_len = seqLensIs2D ? seqLens[rowIdx] : seqLens[batch_idx];
int rowEnd =
seqLensIs2D ? max(0, seq_len) : max(0, seq_len - next_n + next_n_idx + 1);
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto numColumns = logits.size(1);
// seqLensIs2D is 1 if seqLens is 2D (B, next_n): each logit row has its own
// pre-computed effective seq_len. It is 0 for any other rank (expected: 1D
// (B,)): all rows in a batch share the same seq_len and the kernel computes
// the per-row offset itself.
int seqLensIs2D = seqLens.dim() == 2 ? 1 : 0;
if (numColumns < kSortingAlgorithmThreshold) {
// Use insertion sort
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
static_cast<int>(next_n), seqLensIs2D);
} else if (numColumns < kSplitWorkThreshold) {
// From this threshold, use radix sort instead
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
static_cast<int>(next_n), seqLensIs2D);
} else {
// Long sequences are run in two steps
constexpr auto multipleBlocksPerRowConfig = 10;
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
static_cast<int>(next_n), seqLensIs2D,
outLogitsAux.data_ptr<float>());
constexpr int kNumThreadsPerBlockMerge = 1024;
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
}
}

View File

@@ -21,13 +21,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
namespace P = vllm::persistent;