[Refactor] Improve indexer decode path metadata preparation (#38865)
@@ -564,8 +564,9 @@ template <int kNumThreadsPerBlock, bool useRadixSort,
          bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
    const float* logits, const int* seqLens, int* outIndices, int stride0,
    int stride1, const int topK, int next_n, float* outLogits = nullptr,
    const int numBlocksToMerge = 0, const int* indices = nullptr) {
    int stride1, const int topK, int next_n, int seqLensIs2D = 0,
    float* outLogits = nullptr, const int numBlocksToMerge = 0,
    const int* indices = nullptr) {
  // The number of bins in the histogram.
  static constexpr int kNumBins = 2048;

@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(

  // The range of logits within the row.
  int rowStart = 0;
  int seq_len = seqLens[rowIdx / next_n];
  int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
  int batch_idx = rowIdx / next_n;
  int next_n_idx = rowIdx % next_n;
  // seqLensIs2D=0: 1D seqLens — all rows in a batch share the same seq_len;
  //                kernel computes per-row effective length via offset.
  // seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
  //                effective length (flat index rowIdx = b*next_n + j maps
  //                directly to seqLens[b, j] in C-contiguous layout).
  int seq_len = seqLensIs2D ? seqLens[rowIdx] : seqLens[batch_idx];
  int rowEnd =
      seqLensIs2D ? max(0, seq_len) : max(0, seq_len - next_n + next_n_idx + 1);

  // Local pointers to this block
  if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
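
The two branches above are easiest to see side by side. Below is a minimal PyTorch sketch of the same arithmetic (the values of B, next_n, and the lengths are illustrative); it checks that the kernel's 1D offset computation and the host-precomputed 2D lengths agree for every flat row index:

```python
import torch

B, next_n = 2, 3
seq_lens_1d = torch.tensor([10, 7], dtype=torch.int32)  # (B,)
# Host-precomputed 2D form: seq_lens_2d[b, j] = L_b - next_n + j + 1
seq_lens_2d = (
    seq_lens_1d.unsqueeze(1) - next_n + 1 + torch.arange(next_n, dtype=torch.int32)
)  # (B, next_n)

for row_idx in range(B * next_n):
    batch_idx, next_n_idx = divmod(row_idx, next_n)
    # seqLensIs2D == 0: derive the per-row length from the shared seq_len.
    row_end_1d = max(0, int(seq_lens_1d[batch_idx]) - next_n + next_n_idx + 1)
    # seqLensIs2D == 1: read the precomputed per-row length directly.
    row_end_2d = max(0, int(seq_lens_2d.flatten()[row_idx]))
    assert row_end_1d == row_end_2d
```
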
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const auto numColumns = logits.size(1);

  // True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
  // effective seq_len. False if seqLens is 1D (B,): all rows in a batch share
  // the same seq_len and the kernel computes the per-row offset itself.
  int seqLensIs2D = seqLens.dim() == 2 ? 1 : 0;

  if (numColumns < kSortingAlgorithmThreshold) {
    // Use insertion sort
    vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
            logits.data_ptr<float>(), seqLens.data_ptr<int>(),
            indices.data_ptr<int>(), static_cast<int>(stride0),
            static_cast<int>(stride1), static_cast<int>(topK),
            static_cast<int>(next_n));
            static_cast<int>(next_n), seqLensIs2D);
  } else if (numColumns < kSplitWorkThreshold) {
    // From this threshold, use radix sort instead
    vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
            logits.data_ptr<float>(), seqLens.data_ptr<int>(),
            indices.data_ptr<int>(), static_cast<int>(stride0),
            static_cast<int>(stride1), static_cast<int>(topK),
            static_cast<int>(next_n));
            static_cast<int>(next_n), seqLensIs2D);
  } else {
    // Long sequences are run in two steps
    constexpr auto multipleBlocksPerRowConfig = 10;
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
            logits.data_ptr<float>(), seqLens.data_ptr<int>(),
            outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
            static_cast<int>(stride1), static_cast<int>(topK),
            static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
            static_cast<int>(next_n), seqLensIs2D,
            outLogitsAux.data_ptr<float>());

    constexpr int kNumThreadsPerBlockMerge = 1024;
    vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
        <<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
            outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
            indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
            static_cast<int>(topK), static_cast<int>(next_n), nullptr,
            multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
            static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
            nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
  }
}
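
For context, the launcher picks one of three kernels by column count and now threads seqLensIs2D through every launch. The sketch below mirrors that structure only; the two threshold constants are hypothetical placeholders, not the actual values of kSortingAlgorithmThreshold and kSplitWorkThreshold:

```python
def launch_top_k_per_row_decode(seq_lens_dim: int, num_columns: int) -> tuple[str, int]:
    seq_lens_is_2d = 1 if seq_lens_dim == 2 else 0  # mirrors seqLens.dim() == 2
    K_SORTING_ALGORITHM_THRESHOLD = 2048  # placeholder value
    K_SPLIT_WORK_THRESHOLD = 8192         # placeholder value
    if num_columns < K_SORTING_ALGORITHM_THRESHOLD:
        path = "insertion sort"   # single launch
    elif num_columns < K_SPLIT_WORK_THRESHOLD:
        path = "radix sort"       # single launch
    else:
        path = "split + merge"    # two launches; seqLensIs2D goes to both
    return path, seq_lens_is_2d

assert launch_top_k_per_row_decode(2, 100) == ("insertion sort", 1)
```
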
@@ -21,13 +21,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
  TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
  TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
  TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
  TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
  TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
              "lengths must be 1D or 2D");
  TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
  TORCH_CHECK(output.dim() == 2, "output must be 2D");

  const int64_t num_rows = logits.size(0);
  const int64_t stride = logits.size(1);

  TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
  TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
  TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
              "output size mismatch");
  namespace P = vllm::persistent;
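
The relaxed checks amount to a simple shape contract: lengths may be (num_rows,) or (B, next_n), as long as it stays contiguous int32 with one entry per logit row. A Python rendering with asserts standing in for TORCH_CHECK (a sketch, not the bound op):

```python
import torch

def check_persistent_topk_args(logits, lengths, output, k):
    assert lengths.dtype == torch.int32 and output.dtype == torch.int32
    assert logits.dim() == 2, "logits must be 2D"
    assert lengths.dim() in (1, 2), "lengths must be 1D or 2D"
    assert lengths.is_contiguous(), "lengths must be contiguous"
    assert output.dim() == 2, "output must be 2D"
    num_rows = logits.size(0)
    # numel() covers both layouts: (num_rows,) and (B, next_n) with
    # B * next_n == num_rows, which the old .size(0) check would reject.
    assert lengths.numel() == num_rows, "lengths size mismatch"
    assert output.size(0) == num_rows and output.size(1) == k

# Both a flat and a (B, next_n) lengths tensor pass for 6 logit rows:
logits = torch.zeros(6, 128)
out = torch.zeros(6, 8, dtype=torch.int32)
check_persistent_topk_args(logits, torch.ones(6, dtype=torch.int32), out, 8)
check_persistent_topk_args(logits, torch.ones(2, 3, dtype=torch.int32), out, 8)
```
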
@@ -183,13 +183,15 @@ def sparse_attn_indexer(
    # TODO: move and optimize below logic with triton kernels
    batch_size = padded_q_fp8_decode_tokens.shape[0]
    next_n = padded_q_fp8_decode_tokens.shape[1]
    assert batch_size == decode_metadata.seq_lens.shape[0]
    num_padded_tokens = batch_size * next_n
    seq_lens = decode_metadata.seq_lens[:batch_size]
    # seq_lens is (B, next_n) for native spec decode, (B,) otherwise.
    # fp8_paged_mqa_logits and all topk kernels accept both shapes.
    logits = fp8_paged_mqa_logits(
        padded_q_fp8_decode_tokens,
        kv_cache,
        weights[:num_padded_tokens],
        decode_metadata.seq_lens,
        seq_lens,
        decode_metadata.block_table,
        decode_metadata.schedule_metadata,
        max_model_len=max_model_len,
@@ -198,17 +200,6 @@ def sparse_attn_indexer(
    num_rows = logits.shape[0]
    topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]

    if next_n == 1:
        lengths = decode_metadata.seq_lens
    else:
        # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
        lengths = (
            decode_metadata.seq_lens.unsqueeze(1)
            - next_n
            + 1
            + decode_metadata.offsets
        ).flatten()
    if current_platform.is_cuda():
        workspace_manager = current_workspace_manager()
        (topk_workspace,) = workspace_manager.get_simultaneous(
@@ -216,7 +207,7 @@ def sparse_attn_indexer(
        )
        torch.ops._C.persistent_topk(
            logits,
            lengths,
            decode_metadata.seq_lens,
            topk_indices,
            topk_workspace,
            topk_tokens,
@@ -227,7 +218,7 @@ def sparse_attn_indexer(
        ops.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
@@ -238,7 +229,7 @@ def sparse_attn_indexer(
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
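
The deleted if/else is safe to drop because, for next_n > 1, the metadata builder now materializes exactly the lengths it used to compute at this call site, as a 2D seq_lens that the kernels consume directly. A minimal equivalence check (illustrative values; offsets is arange(next_n) as in the builder):

```python
import torch

next_n = 3
seq_lens_1d = torch.tensor([10, 7], dtype=torch.int32)  # (bs,)
offsets = torch.arange(next_n, dtype=torch.int32)       # (next_n,)

# Old call-site computation: (bs,) -> (bs, next_n) -> (bs * next_n,)
old_lengths = (seq_lens_1d.unsqueeze(1) - next_n + 1 + offsets).flatten()

# New: builder-provided 2D seq_lens, consumed with no flattening arithmetic.
new_seq_lens = seq_lens_1d.unsqueeze(1) - next_n + 1 + offsets  # (bs, next_n)

assert torch.equal(old_lengths, new_seq_lens.flatten())
```
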
@@ -141,11 +141,14 @@ class DeepseekV32IndexerPrefillMetadata:
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    block_table: torch.Tensor
    # seq_lens: per-token effective context lengths.
    # - flatten path / plain decode: 1D (batch_size,)
    # - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1
    # Both fp8_paged_mqa_logits and the topk kernels accept both shapes.
    seq_lens: torch.Tensor
    decode_lens: torch.Tensor
    requires_padding: bool
    schedule_metadata: torch.Tensor
    offsets: torch.Tensor | None  # Precomputed offsets for speculative decoding
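
Spelling out the documented native-MTP layout, with illustrative values L = [9, 5] and next_n = 2:

```python
import torch

L = torch.tensor([9, 5], dtype=torch.int32)
next_n = 2
seq_lens = L.unsqueeze(1) - next_n + 1 + torch.arange(next_n, dtype=torch.int32)

for b in range(L.numel()):
    for j in range(next_n):
        # Matches the field comment: [b, j] = L_b - next_n + j + 1
        assert seq_lens[b, j] == L[b] - next_n + j + 1
assert seq_lens.tolist() == [[8, 9], [4, 5]]
```
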
@dataclass
@@ -283,21 +286,31 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
        sm_count = num_compute_units(self.device.index)
        self.num_sms = sm_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )
        self.offsets_buffer = torch.arange(
            next_n, device=self.device, dtype=torch.int32
        )
        self.arange_buffer = torch.arange(
            scheduler_config.max_num_seqs * next_n,
        self.decode_lens_buffer = torch.zeros(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )
        self.expanded_seq_lens_buffer = torch.zeros(
            (scheduler_config.max_num_batched_tokens,),
        if not self.use_flattening and next_n > 1:
            # Native MTP: 2D buffer for per-token seq_lens.
            # Flattening path is never used, so no expanded_seq_lens_buffer.
            self.decode_seq_lens_buffer = torch.zeros(
                (scheduler_config.max_num_seqs, next_n),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            # Flattening or no MTP: 1D buffer for expanded per-token seq_lens.
            self.decode_seq_lens_buffer = torch.zeros(
                (scheduler_config.max_num_batched_tokens,),
                dtype=torch.int32,
                device=self.device,
            )
        self.arange_buffer = torch.arange(
            scheduler_config.max_num_seqs * next_n,
            dtype=torch.int32,
            device=self.device,
        )
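
The allocation logic above reduces to one shape decision. A condensed sketch as a standalone function (a hypothetical helper; the real code allocates the buffers on self.device inside __init__):

```python
import torch

def make_decode_seq_lens_buffer(use_flattening: bool, next_n: int,
                                max_num_seqs: int, max_num_batched_tokens: int):
    if not use_flattening and next_n > 1:
        # Native MTP: one row per request, one column per speculative token.
        return torch.zeros((max_num_seqs, next_n), dtype=torch.int32)
    # Flattening or plain decode: one entry per (expanded) token.
    return torch.zeros((max_num_batched_tokens,), dtype=torch.int32)

assert make_decode_seq_lens_buffer(False, 4, 8, 256).shape == (8, 4)
assert make_decode_seq_lens_buffer(True, 4, 8, 256).shape == (256,)
```
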
@@ -367,6 +380,96 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            skip_kv_gather=skip_kv_gather,
        )

    def _prepare_decode_tensors(
        self,
        seq_lens: torch.Tensor,
        block_table: torch.Tensor,
        decode_lens: torch.Tensor,
        decode_lens_cpu: torch.Tensor,
        query_start_loc: torch.Tensor,
        num_decodes: int,
        num_decode_tokens: int,
        use_native: bool,
        next_n: int,
        max_decode_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool]:
        """Expand seq_lens/block_table/decode_lens for the decode kernels.

        Flatten path (not use_native, max_decode_len > 1):
            Each multi-token decode request is expanded into individual
            single-token entries so the kernel always sees next_n=1.

        Native path (use_native or max_decode_len == 1):
            Plain decode or spec-decode with 2D per-token context lengths.

        Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
        seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
        """
        if not use_native and max_decode_len > 1:
            assert self.decode_seq_lens_buffer.dim() == 1
            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
            # The context lengths are therefore
            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

            # 3 + 1 + 4 + 0 = 8
            actual_expanded = int(decode_lens_cpu.sum().item())

            # Fuse expanded_base and expanded_starts into a single repeat_interleave:
            # seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
            # where context_start[b] = seq_lens[b] - decode_lens[b].
            # Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
            # expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
            # result = [8, 9, 10, 7, 9, 10, 11, 12]
            expanded_offsets = torch.repeat_interleave(
                seq_lens - decode_lens - query_start_loc,
                decode_lens,
                output_size=actual_expanded,
            )

            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
            self.decode_seq_lens_buffer[:actual_expanded] = (
                expanded_offsets + self.arange_buffer[:actual_expanded] + 1
            )
            self.decode_seq_lens_buffer[actual_expanded:] = 0
            seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]

            # Give each of the flattened entries the same block table row as the
            # original request.
            self.expanded_block_table_buffer[:actual_expanded] = (
                torch.repeat_interleave(
                    block_table, decode_lens, dim=0, output_size=actual_expanded
                )
            )
            if actual_expanded < num_decode_tokens:
                self.expanded_block_table_buffer[
                    actual_expanded:num_decode_tokens, 0
                ] = 0
            block_table = self.expanded_block_table_buffer[:num_decode_tokens]

            # All reqs now have decode_len=1
            self.decode_lens_buffer[:num_decode_tokens] = 1
            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
            return seq_lens, block_table, decode_lens, num_decode_tokens, False
        else:
            # Native path: plain decode (next_n==1) or spec decode
            # with 2D per-token context lengths (next_n > 1).
            #
            # When decode_lens are not truly uniform (e.g. some requests have
            # decode_len < next_n due to padding or short prefills), the simple
            # reshape in sparse_attn_indexer won't work. Use pack_seq_triton
            # (requires_padding) instead.
            min_decode_len = int(decode_lens_cpu.min().item())
            requires_padding = min_decode_len != max_decode_len
            if use_native and next_n > 1:
                assert self.decode_seq_lens_buffer.dim() == 2
                # (B, next_n): token j attends to L - next_n + j + 1 KV tokens
                self.decode_seq_lens_buffer[:num_decodes] = (
                    seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer
                )
                seq_lens = self.decode_seq_lens_buffer[:num_decodes]
            return seq_lens, block_table, decode_lens, num_decodes, requires_padding
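
Replaying both branches of _prepare_decode_tensors numerically, using the example from its own comments for the flatten path and illustrative values for the native path (a standalone sketch, not the method itself):

```python
import torch

# Flatten path: seq_lens [10, 7, 12, 0], decode_lens [3, 1, 4, 0],
# query_start_loc [0, 3, 4, 8], as in the comments above.
seq_lens = torch.tensor([10, 7, 12, 0], dtype=torch.int32)
decode_lens = torch.tensor([3, 1, 4, 0], dtype=torch.int32)
query_start_loc = torch.tensor([0, 3, 4, 8], dtype=torch.int32)
actual_expanded = int(decode_lens.sum())  # 3 + 1 + 4 + 0 = 8

expanded_offsets = torch.repeat_interleave(
    seq_lens - decode_lens - query_start_loc,  # [7, 3, 4, -8]
    decode_lens,
    output_size=actual_expanded,
)                                              # [7, 7, 7, 3, 4, 4, 4, 4]
expanded_seq_lens = expanded_offsets + torch.arange(actual_expanded) + 1
assert expanded_seq_lens.tolist() == [8, 9, 10, 7, 9, 10, 11, 12]

# Native path: 2D per-token lengths plus the padding check
# (values below are illustrative).
native_seq_lens = torch.tensor([10, 7], dtype=torch.int32)
decode_lens_cpu = torch.tensor([2, 2], dtype=torch.int32)
next_n = max_decode_len = 2
offsets_buffer = torch.arange(next_n, dtype=torch.int32)

requires_padding = int(decode_lens_cpu.min()) != max_decode_len
seq_lens_2d = native_seq_lens.unsqueeze(1) - next_n + 1 + offsets_buffer
assert not requires_padding
assert seq_lens_2d.tolist() == [[9, 10], [6, 7]]
```
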
    def build(
        self,
        common_prefix_len: int,
@@ -434,68 +537,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
        next_n = 1 + self.num_speculative_tokens
        use_native = not self.use_flattening and max_decode_len == next_n

        if use_native and next_n > 1:
            offsets = self.offsets_buffer
        elif max_decode_len > 1:
            # Flatten multi-token decode requests into single-token
            # batch entries, expanding seq_lens and block tables so
            # the kernel always sees next_n=1.

            # Also handles the edge case where use_flattening=False
            # but max_decode_len != next_n (e.g. a batch containing some
            # short prefills (q_len < next_n) and no true decodes).

            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
            # The context lengths are therefore
            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

            # 3 + 1 + 4 + 0 = 8
            actual_expanded = int(decode_lens_cpu.sum().item())

            # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
            expanded_base = torch.repeat_interleave(
                seq_lens - decode_lens, decode_lens, output_size=actual_expanded
        seq_lens, block_table, decode_lens, batch_size, requires_padding = (
            self._prepare_decode_tensors(
                seq_lens=seq_lens,
                block_table=block_table,
                decode_lens=decode_lens,
                decode_lens_cpu=decode_lens_cpu,
                query_start_loc=common_attn_metadata.query_start_loc[:num_decodes],
                num_decodes=num_decodes,
                num_decode_tokens=num_decode_tokens,
                use_native=use_native,
                next_n=next_n,
                max_decode_len=max_decode_len,
            )

            # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
            expanded_starts = torch.repeat_interleave(
                common_attn_metadata.query_start_loc[:num_decodes],
                decode_lens,
                output_size=actual_expanded,
            )

            # [0, 1, 2, 0, 0, 1, 2, 3]
            positions_within = (
                self.arange_buffer[:actual_expanded] - expanded_starts
            )

            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
            self.expanded_seq_lens_buffer[:actual_expanded] = (
                expanded_base + positions_within + 1
            )
            self.expanded_seq_lens_buffer[actual_expanded:] = 0
            seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]

            # Give each of the flattened entries the same block table row as the
            # original request.
            self.expanded_block_table_buffer[:actual_expanded] = (
                torch.repeat_interleave(
                    block_table, decode_lens, dim=0, output_size=actual_expanded
                )
            )
            if actual_expanded < num_decode_tokens:
                self.expanded_block_table_buffer[
                    actual_expanded:num_decode_tokens, 0
                ] = 0
            block_table = self.expanded_block_table_buffer[:num_decode_tokens]

            # All reqs now have decode_len=1
            self.decode_lens_buffer[:num_decode_tokens] = 1
            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
            offsets = None
        else:
            offsets = None
        )

        # DeepGEMM is required for the paged MQA logits on CUDA devices
        if current_platform.is_cuda() and has_deep_gemm():
@@ -509,9 +564,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            block_table=block_table,
            seq_lens=seq_lens,
            decode_lens=decode_lens,
            requires_padding=False,
            requires_padding=requires_padding,
            schedule_metadata=self.scheduler_metadata_buffer,
            offsets=offsets,
        )

        attn_metadata = DeepseekV32IndexerMetadata(
@@ -531,6 +585,4 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            decode=decode_metadata,
        )

        # if get_tensor_model_parallel_rank() == 0:
        # logger.info(f"attn_metadata: {attn_metadata}")
        return attn_metadata