[Refactor] Improve indexer decode path metadata preparation (#38865)
This commit is contained in:
@@ -564,8 +564,9 @@ template <int kNumThreadsPerBlock, bool useRadixSort,
|
||||
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
|
||||
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
const float* logits, const int* seqLens, int* outIndices, int stride0,
|
||||
int stride1, const int topK, int next_n, float* outLogits = nullptr,
|
||||
const int numBlocksToMerge = 0, const int* indices = nullptr) {
|
||||
int stride1, const int topK, int next_n, int seqLensIs2D = 0,
|
||||
float* outLogits = nullptr, const int numBlocksToMerge = 0,
|
||||
const int* indices = nullptr) {
|
||||
// The number of bins in the histogram.
|
||||
static constexpr int kNumBins = 2048;
|
||||
|
||||
@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// The range of logits within the row.
|
||||
int rowStart = 0;
|
||||
int seq_len = seqLens[rowIdx / next_n];
|
||||
int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
|
||||
int batch_idx = rowIdx / next_n;
|
||||
int next_n_idx = rowIdx % next_n;
|
||||
// seqLensIs2D=0: 1D seqLens — all rows in a batch share the same seq_len;
|
||||
// kernel computes per-row effective length via offset.
|
||||
// seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
|
||||
// effective length (flat index rowIdx = b*next_n + j maps
|
||||
// directly to seqLens[b, j] in C-contiguous layout).
|
||||
int seq_len = seqLensIs2D ? seqLens[rowIdx] : seqLens[batch_idx];
|
||||
int rowEnd =
|
||||
seqLensIs2D ? max(0, seq_len) : max(0, seq_len - next_n + next_n_idx + 1);
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const auto numColumns = logits.size(1);
|
||||
|
||||
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
|
||||
// effective seq_len. False if seqLens is 1D (B,): all rows in a batch share
|
||||
// the same seq_len and the kernel computes the per-row offset itself.
|
||||
int seqLensIs2D = seqLens.dim() == 2 ? 1 : 0;
|
||||
|
||||
if (numColumns < kSortingAlgorithmThreshold) {
|
||||
// Use insertion sort
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
|
||||
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n));
|
||||
static_cast<int>(next_n), seqLensIs2D);
|
||||
} else if (numColumns < kSplitWorkThreshold) {
|
||||
// From this threshold, use radix sort instead
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
|
||||
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n));
|
||||
static_cast<int>(next_n), seqLensIs2D);
|
||||
} else {
|
||||
// Long sequences are run in two steps
|
||||
constexpr auto multipleBlocksPerRowConfig = 10;
|
||||
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
|
||||
static_cast<int>(next_n), seqLensIs2D,
|
||||
outLogitsAux.data_ptr<float>());
|
||||
|
||||
constexpr int kNumThreadsPerBlockMerge = 1024;
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
|
||||
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
|
||||
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
|
||||
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
|
||||
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
||||
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
|
||||
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,13 +21,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
||||
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
||||
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||
TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
|
||||
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||
"lengths must be 1D or 2D");
|
||||
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||
|
||||
const int64_t num_rows = logits.size(0);
|
||||
const int64_t stride = logits.size(1);
|
||||
|
||||
TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
|
||||
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||
"output size mismatch");
|
||||
namespace P = vllm::persistent;
|
||||
|
||||
Reference in New Issue
Block a user