[Refactor] Improve indexer decode path metadata preparation (#38865)

Author: Yongye Zhu
Date: 2026-04-08 23:49:15 -04:00
Committed by: GitHub
Parent: ef5a226819
Commit: 2e98406048
4 changed files with 162 additions and 102 deletions

View File

@@ -564,8 +564,9 @@ template <int kNumThreadsPerBlock, bool useRadixSort,
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
const float* logits, const int* seqLens, int* outIndices, int stride0,
int stride1, const int topK, int next_n, float* outLogits = nullptr,
const int numBlocksToMerge = 0, const int* indices = nullptr) {
int stride1, const int topK, int next_n, int seqLensIs2D = 0,
float* outLogits = nullptr, const int numBlocksToMerge = 0,
const int* indices = nullptr) {
// The number of bins in the histogram.
static constexpr int kNumBins = 2048;
@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// The range of logits within the row.
int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
int batch_idx = rowIdx / next_n;
int next_n_idx = rowIdx % next_n;
// seqLensIs2D=0: 1D seqLens — all rows in a batch share the same seq_len;
// kernel computes per-row effective length via offset.
// seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
// effective length (flat index rowIdx = b*next_n + j maps
// directly to seqLens[b, j] in C-contiguous layout).
int seq_len = seqLensIs2D ? seqLens[rowIdx] : seqLens[batch_idx];
int rowEnd =
seqLensIs2D ? max(0, seq_len) : max(0, seq_len - next_n + next_n_idx + 1);
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
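
To make the new per-row length selection concrete: with 1D seqLens the kernel reconstructs each speculative token's effective context length from its offset within the request, while with 2D seqLens the metadata builder has already done that per row. A minimal Python sketch of the index math (illustrative values only, not part of the commit):

    # Illustrative check that the 1D and 2D seqLens conventions agree.
    next_n = 3
    seq_lens_1d = [10, 7]  # one raw context length per request (B = 2)
    # 2D convention, flattened C-contiguously: entry [b, j] = L_b - next_n + j + 1
    seq_lens_2d = [L - next_n + j + 1 for L in seq_lens_1d for j in range(next_n)]

    for row_idx in range(len(seq_lens_1d) * next_n):
        batch_idx, next_n_idx = divmod(row_idx, next_n)
        row_end_1d = max(0, seq_lens_1d[batch_idx] - next_n + next_n_idx + 1)
        row_end_2d = max(0, seq_lens_2d[row_idx])  # precomputed per-row length
        assert row_end_1d == row_end_2d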
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto numColumns = logits.size(1);
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
// effective seq_len. False if seqLens is 1D (B,): all rows in a batch share
// the same seq_len and the kernel computes the per-row offset itself.
int seqLensIs2D = seqLens.dim() == 2 ? 1 : 0;
if (numColumns < kSortingAlgorithmThreshold) {
// Use insertion sort
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
static_cast<int>(next_n), seqLensIs2D);
} else if (numColumns < kSplitWorkThreshold) {
// From this threshold, use radix sort instead
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
static_cast<int>(next_n), seqLensIs2D);
} else {
// Long sequences are run in two steps
constexpr auto multipleBlocksPerRowConfig = 10;
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
static_cast<int>(next_n), seqLensIs2D,
outLogitsAux.data_ptr<float>());
constexpr int kNumThreadsPerBlockMerge = 1024;
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
}
}
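
The dispatcher above keeps its three regimes: insertion sort for rows below kSortingAlgorithmThreshold, single-block radix sort below kSplitWorkThreshold, and for long rows a two-step pass that runs multipleBlocksPerRowConfig (= 10) blocks per row and then merges the 10 * topK survivors. A pure-Python sketch of that split/merge idea (a hypothetical helper, not the CUDA implementation; values are made up):

    import heapq

    def topk_two_step(row, k, num_blocks):
        # Phase 1: each block keeps the top-k of its slice of the row.
        n = len(row)
        chunk = (n + num_blocks - 1) // num_blocks
        candidates = []
        for b in range(num_blocks):
            block = range(b * chunk, min((b + 1) * chunk, n))
            candidates += heapq.nlargest(k, ((row[i], i) for i in block))
        # Phase 2 (merge): top-k over the num_blocks * k surviving candidates.
        return [i for _, i in heapq.nlargest(k, candidates)]

    row = [float(x * 37 % 101) for x in range(5000)]
    one_step = [i for _, i in heapq.nlargest(8, ((v, i) for i, v in enumerate(row)))]
    assert topk_two_step(row, 8, num_blocks=10) == one_step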

View File

@@ -21,13 +21,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
namespace P = vllm::persistent;
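
Relaxing `lengths.size(0) == num_rows` to `lengths.numel() == num_rows` is what lets both layouts through: a 1D (num_rows,) tensor on the flatten/plain path, or a contiguous 2D (B, next_n) tensor with B * next_n == num_rows on the native MTP path. A small sketch of both shapes passing the check (illustrative sizes):

    import torch

    B, next_n, n_cols = 4, 2, 128
    logits = torch.randn(B * next_n, n_cols)
    lengths_1d = torch.randint(1, n_cols, (B * next_n,), dtype=torch.int32)  # flatten path
    lengths_2d = torch.randint(1, n_cols, (B, next_n), dtype=torch.int32)    # native MTP
    assert lengths_1d.numel() == logits.size(0)
    assert lengths_2d.numel() == logits.size(0)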

View File

@@ -183,13 +183,15 @@ def sparse_attn_indexer(
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
seq_lens = decode_metadata.seq_lens[:batch_size]
# seq_lens is (B, next_n) for native spec decode, (B,) otherwise.
# fp8_paged_mqa_logits and all topk kernels accept both shapes.
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
@@ -198,17 +200,6 @@ def sparse_attn_indexer(
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
if next_n == 1:
lengths = decode_metadata.seq_lens
else:
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths = (
decode_metadata.seq_lens.unsqueeze(1)
- next_n
+ 1
+ decode_metadata.offsets
).flatten()
if current_platform.is_cuda():
workspace_manager = current_workspace_manager()
(topk_workspace,) = workspace_manager.get_simultaneous(
@@ -216,7 +207,7 @@ def sparse_attn_indexer(
)
torch.ops._C.persistent_topk(
logits,
lengths,
decode_metadata.seq_lens,
topk_indices,
topk_workspace,
topk_tokens,
@@ -227,7 +218,7 @@ def sparse_attn_indexer(
ops.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
seq_lens,
topk_indices,
num_rows,
logits.stride(0),
@@ -238,7 +229,7 @@ def sparse_attn_indexer(
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
seq_lens,
topk_indices,
num_rows,
logits.stride(0),
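
The net effect in this file: the ad-hoc `lengths` expansion (removed above) is gone, and every consumer receives the seq_lens tensor in whichever shape the metadata builder prepared. The two are numerically the same, as this illustrative sketch shows (values made up):

    import torch

    next_n = 3
    seq_lens = torch.tensor([10, 7], dtype=torch.int32)  # raw per-request lengths (B,)
    offsets = torch.arange(next_n, dtype=torch.int32)    # the old precomputed offsets

    # Old path (removed above): expand to per-token lengths at call time.
    old_lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
    # New path: the builder fills a (B, next_n) buffer once; kernels take it as-is.
    new_seq_lens = seq_lens.unsqueeze(1) - next_n + 1 + offsets
    assert torch.equal(new_seq_lens.flatten(), old_lengths)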

View File

@@ -141,11 +141,14 @@ class DeepseekV32IndexerPrefillMetadata:
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
# seq_lens: per-token effective context lengths.
# - flatten path / plain decode: 1D (batch_size,)
# - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1
# Both fp8_paged_mqa_logits and the topk kernels accept both shapes.
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
@dataclass
@@ -283,21 +286,31 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
self.offsets_buffer = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * next_n,
self.decode_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
if not self.use_flattening and next_n > 1:
# Native MTP: 2D buffer for per-token seq_lens.
# Flattening path is never used, so no expanded_seq_lens_buffer.
self.decode_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_seqs, next_n),
dtype=torch.int32,
device=self.device,
)
else:
# Flattening or no MTP: 1D buffer for expanded per-token seq_lens.
self.decode_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * next_n,
dtype=torch.int32,
device=self.device,
)
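
The buffer shape is fixed at init time because the path is fixed at init time: native MTP needs one (max_num_seqs, next_n) row per request, while flattening (or next_n == 1) needs one slot per batched token. A sketch of the selection logic (hypothetical helper, illustrative limits):

    # Illustrative: which decode_seq_lens_buffer shape __init__ allocates.
    max_num_seqs, max_num_batched_tokens, next_n = 8, 64, 3

    def buffer_shape(use_flattening: bool) -> tuple:
        if not use_flattening and next_n > 1:
            return (max_num_seqs, next_n)    # native MTP: (B, next_n)
        return (max_num_batched_tokens,)     # flattening / plain decode: per token

    assert buffer_shape(use_flattening=False) == (8, 3)
    assert buffer_shape(use_flattening=True) == (64,)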
@@ -367,6 +380,96 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
skip_kv_gather=skip_kv_gather,
)
def _prepare_decode_tensors(
self,
seq_lens: torch.Tensor,
block_table: torch.Tensor,
decode_lens: torch.Tensor,
decode_lens_cpu: torch.Tensor,
query_start_loc: torch.Tensor,
num_decodes: int,
num_decode_tokens: int,
use_native: bool,
next_n: int,
max_decode_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool]:
"""Expand seq_lens/block_table/decode_lens for the decode kernels.
Flatten path (not use_native, max_decode_len > 1):
Each multi-token decode request is expanded into individual
single-token entries so the kernel always sees next_n=1.
Native path (use_native or max_decode_len == 1):
Plain decode or spec-decode with 2D per-token context lengths.
Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
"""
if not use_native and max_decode_len > 1:
assert self.decode_seq_lens_buffer.dim() == 1
# The example comments below assume 4 requests with seq_lens [10, 7, 12, 0]
# (the final req is padding) and decode_lens [3, 1, 4, 0].
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# actual_expanded = 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# Fuse expanded_base and expanded_starts into a single repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets = torch.repeat_interleave(
seq_lens - decode_lens - query_start_loc,
decode_lens,
output_size=actual_expanded,
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.decode_seq_lens_buffer[:actual_expanded] = (
expanded_offsets + self.arange_buffer[:actual_expanded] + 1
)
self.decode_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
return seq_lens, block_table, decode_lens, num_decode_tokens, False
else:
# Native path: plain decode (next_n==1) or spec decode
# with 2D per-token context lengths (next_n > 1).
#
# When decode_lens are not truly uniform (e.g. some requests have
# decode_len < next_n due to padding or short prefills), the simple
# reshape in sparse_attn_indexer won't work. Use pack_seq_triton
# (requires_padding) instead.
min_decode_len = int(decode_lens_cpu.min().item())
requires_padding = min_decode_len != max_decode_len
if use_native and next_n > 1:
assert self.decode_seq_lens_buffer.dim() == 2
# (B, next_n): token j attends to L - next_n + j + 1 KV tokens
self.decode_seq_lens_buffer[:num_decodes] = (
seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer
)
seq_lens = self.decode_seq_lens_buffer[:num_decodes]
return seq_lens, block_table, decode_lens, num_decodes, requires_padding
def build(
self,
common_prefix_len: int,
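
Running the flatten path of _prepare_decode_tensors on the example from its comments reproduces the quoted numbers; a standalone sketch (same illustrative values):

    import torch

    seq_lens = torch.tensor([10, 7, 12, 0])     # last request is padding
    decode_lens = torch.tensor([3, 1, 4, 0])
    query_start_loc = torch.tensor([0, 3, 4, 8])
    actual_expanded = int(decode_lens.sum())    # 3 + 1 + 4 + 0 = 8
    arange = torch.arange(actual_expanded)

    expanded_offsets = torch.repeat_interleave(
        seq_lens - decode_lens - query_start_loc,  # [7, 3, 4, -8]
        decode_lens,
        output_size=actual_expanded,
    )                                              # [7, 7, 7, 3, 4, 4, 4, 4]
    per_token_seq_lens = expanded_offsets + arange + 1
    assert per_token_seq_lens.tolist() == [8, 9, 10, 7, 9, 10, 11, 12]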
@@ -434,68 +537,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
next_n = 1 + self.num_speculative_tokens
use_native = not self.use_flattening and max_decode_len == next_n
if use_native and next_n > 1:
offsets = self.offsets_buffer
elif max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Also handles the edge case where use_flattening=False
# but max_decode_len != next_n (e.g. a batch containing some
# short prefills (q_len < next_n) and no true decodes).
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens, output_size=actual_expanded
seq_lens, block_table, decode_lens, batch_size, requires_padding = (
self._prepare_decode_tensors(
seq_lens=seq_lens,
block_table=block_table,
decode_lens=decode_lens,
decode_lens_cpu=decode_lens_cpu,
query_start_loc=common_attn_metadata.query_start_loc[:num_decodes],
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
use_native=use_native,
next_n=next_n,
max_decode_len=max_decode_len,
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes],
decode_lens,
output_size=actual_expanded,
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
else:
offsets = None
)
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
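
The removed build() code computed the same expansion with two repeat_interleaves plus a positions_within term; the new helper fuses them into a single gather. Because repeat_interleave is linear in its first argument, the two forms agree, e.g. with the same illustrative values as above:

    import torch

    seq_lens = torch.tensor([10, 7, 12, 0])
    decode_lens = torch.tensor([3, 1, 4, 0])
    qsl = torch.tensor([0, 3, 4, 8])            # query_start_loc
    n = int(decode_lens.sum())
    arange = torch.arange(n)

    def ri(t):  # repeat each entry decode_lens[i] times
        return torch.repeat_interleave(t, decode_lens, output_size=n)

    old = ri(seq_lens - decode_lens) + (arange - ri(qsl)) + 1  # removed form
    new = ri(seq_lens - decode_lens - qsl) + arange + 1        # fused form
    assert torch.equal(old, new)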
@@ -509,9 +564,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
block_table=block_table,
seq_lens=seq_lens,
decode_lens=decode_lens,
requires_padding=False,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
offsets=offsets,
)
attn_metadata = DeepseekV32IndexerMetadata(
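
requires_padding was previously hard-coded to False on this path; it now reflects whether the batch's decode lengths are uniform, which is what decides between the plain reshape and the pack_seq_triton fallback downstream. A small sketch (illustrative values):

    import torch

    decode_lens_cpu = torch.tensor([3, 1, 4], dtype=torch.int32)
    min_decode_len = int(decode_lens_cpu.min().item())
    max_decode_len = int(decode_lens_cpu.max().item())
    requires_padding = min_decode_len != max_decode_len
    assert requires_padding  # non-uniform lengths -> pack_seq_triton path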
@@ -531,6 +585,4 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
decode=decode_metadata,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata