diff --git a/csrc/sampler.cu b/csrc/sampler.cu
index 2e76873c8..c0cc03a08 100644
--- a/csrc/sampler.cu
+++ b/csrc/sampler.cu
@@ -564,8 +564,9 @@ template
 static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
     const float* logits, const int* seqLens, int* outIndices, int stride0,
-    int stride1, const int topK, int next_n, float* outLogits = nullptr,
-    const int numBlocksToMerge = 0, const int* indices = nullptr) {
+    int stride1, const int topK, int next_n, int seqLensIs2D = 0,
+    float* outLogits = nullptr, const int numBlocksToMerge = 0,
+    const int* indices = nullptr) {
   // The number of bins in the histogram.
   static constexpr int kNumBins = 2048;
@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
   // The range of logits within the row.
   int rowStart = 0;
-  int seq_len = seqLens[rowIdx / next_n];
-  int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
+  int batch_idx = rowIdx / next_n;
+  int next_n_idx = rowIdx % next_n;
+  // seqLensIs2D=0: 1D seqLens — all rows in a batch share the same seq_len;
+  //                kernel computes per-row effective length via offset.
+  // seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
+  //                effective length (flat index rowIdx = b*next_n + j maps
+  //                directly to seqLens[b, j] in C-contiguous layout).
+  int seq_len = seqLensIs2D ? seqLens[rowIdx] : seqLens[batch_idx];
+  int rowEnd =
+      seqLensIs2D ? max(0, seq_len) : max(0, seq_len - next_n + next_n_idx + 1);

   // Local pointers to this block
   if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   const auto numColumns = logits.size(1);

+  // True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
+  // effective seq_len. False if seqLens is 1D (B,): all rows in a batch share
+  // the same seq_len and the kernel computes the per-row offset itself.
+  int seqLensIs2D = seqLens.dim() == 2 ? 1 : 0;
+
   if (numColumns < kSortingAlgorithmThreshold) {
     // Use insertion sort
     vllm::topKPerRowDecode
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
             logits.data_ptr(), seqLens.data_ptr(), indices.data_ptr(),
             static_cast(stride0), static_cast(stride1), static_cast(topK),
-            static_cast(next_n));
+            static_cast(next_n), seqLensIs2D);
   } else if (numColumns < kSplitWorkThreshold) {
     // From this threshold, use radix sort instead
     vllm::topKPerRowDecode
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
             logits.data_ptr(), seqLens.data_ptr(), indices.data_ptr(),
             static_cast(stride0), static_cast(stride1), static_cast(topK),
-            static_cast(next_n));
+            static_cast(next_n), seqLensIs2D);
   } else {
     // Long sequences are run in two steps
     constexpr auto multipleBlocksPerRowConfig = 10;
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
             logits.data_ptr(), seqLens.data_ptr(), outIndicesAux.data_ptr(),
             static_cast(stride0), static_cast(stride1), static_cast(topK),
-            static_cast(next_n), outLogitsAux.data_ptr());
+            static_cast(next_n), seqLensIs2D,
+            outLogitsAux.data_ptr());

     constexpr int kNumThreadsPerBlockMerge = 1024;
     vllm::topKPerRowDecode
         <<>>(
             outLogitsAux.data_ptr(), seqLens.data_ptr(), indices.data_ptr(),
             multipleBlocksPerRowConfig * topK, 1,
-            static_cast(topK), static_cast(next_n), nullptr,
-            multipleBlocksPerRowConfig, outIndicesAux.data_ptr());
+            static_cast(topK), static_cast(next_n), seqLensIs2D,
+            nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr());
   }
 }
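The semantics of the new `seqLensIs2D` flag in `topKPerRowDecode` can be summarized with a small host-side reference model. The sketch below is illustrative only (plain Python/PyTorch, not part of the PR); `row_end_reference` is a hypothetical helper name, and it assumes the 2D tensor is built as `L_b - next_n + j + 1`, as described in the kernel comment and the indexer metadata docstring.

```python
# Reference model (assumption: illustrative only) of how topKPerRowDecode
# derives the per-row logit range end for both seqLens layouts.
import torch

def row_end_reference(seq_lens: torch.Tensor, next_n: int, row_idx: int) -> int:
    """Mirror of the kernel's rowEnd computation for a flat row index."""
    if seq_lens.dim() == 2:  # seqLensIs2D == 1: (B, next_n), read directly
        return max(0, int(seq_lens.reshape(-1)[row_idx]))
    batch_idx, next_n_idx = divmod(row_idx, next_n)  # seqLensIs2D == 0: (B,)
    return max(0, int(seq_lens[batch_idx]) - next_n + next_n_idx + 1)

next_n = 3
seq_lens_1d = torch.tensor([10, 7, 12], dtype=torch.int32)
# 2D form precomputed by the metadata builder: [b, j] = L_b - next_n + j + 1
seq_lens_2d = seq_lens_1d.unsqueeze(1) - next_n + 1 + torch.arange(next_n, dtype=torch.int32)

for row in range(seq_lens_1d.numel() * next_n):
    assert row_end_reference(seq_lens_1d, next_n, row) == row_end_reference(seq_lens_2d, next_n, row)
```

Both layouts yield the same per-row effective length; the 2D path simply moves the offset arithmetic out of the kernel and into the metadata builder.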
diff --git a/csrc/topk.cu b/csrc/topk.cu
index 402b64b02..f48e7cbc4 100644
--- a/csrc/topk.cu
+++ b/csrc/topk.cu
@@ -21,13 +21,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
   TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
   TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
   TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
-  TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D");
+  TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
+              "lengths must be 1D or 2D");
+  TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
   TORCH_CHECK(output.dim() == 2, "output must be 2D");

   const int64_t num_rows = logits.size(0);
   const int64_t stride = logits.size(1);
-  TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch");
+  TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
   TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
               "output size mismatch");

   namespace P = vllm::persistent;
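For `persistent_topk`, the contract change is that `lengths` may now be `(num_rows,)` or `(B, next_n)` with `B * next_n == num_rows`; the kernel presumably indexes it flat (hence the new contiguity check), which is why the size check moves from `size(0)` to `numel()`. A rough Python equivalent of the validation, with `check_lengths` as a hypothetical illustration-only helper:

```python
import torch

def check_lengths(logits: torch.Tensor, lengths: torch.Tensor) -> None:
    # Mirrors the updated TORCH_CHECKs: 1D or 2D, contiguous, and one entry
    # per logits row regardless of shape (numel instead of size(0)).
    assert lengths.dtype == torch.int32, "lengths must be int32"
    assert lengths.dim() in (1, 2), "lengths must be 1D or 2D"
    assert lengths.is_contiguous(), "lengths must be contiguous"
    assert lengths.numel() == logits.size(0), "lengths size mismatch"

B, next_n, vocab = 4, 2, 1024
logits = torch.randn(B * next_n, vocab)
check_lengths(logits, torch.randint(1, 100, (B * next_n,), dtype=torch.int32))  # 1D
check_lengths(logits, torch.randint(1, 100, (B, next_n), dtype=torch.int32))    # 2D
```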
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 1844b7556..9e1cd19a0 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -183,13 +183,15 @@ def sparse_attn_indexer(
         # TODO: move and optimize below logic with triton kernels
         batch_size = padded_q_fp8_decode_tokens.shape[0]
         next_n = padded_q_fp8_decode_tokens.shape[1]
-        assert batch_size == decode_metadata.seq_lens.shape[0]
         num_padded_tokens = batch_size * next_n
+        seq_lens = decode_metadata.seq_lens[:batch_size]
+        # seq_lens is (B, next_n) for native spec decode, (B,) otherwise.
+        # fp8_paged_mqa_logits and all topk kernels accept both shapes.
         logits = fp8_paged_mqa_logits(
             padded_q_fp8_decode_tokens,
             kv_cache,
             weights[:num_padded_tokens],
-            decode_metadata.seq_lens,
+            seq_lens,
             decode_metadata.block_table,
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
@@ -198,17 +200,6 @@
         num_rows = logits.shape[0]
         topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]

-        if next_n == 1:
-            lengths = decode_metadata.seq_lens
-        else:
-            # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
-            lengths = (
-                decode_metadata.seq_lens.unsqueeze(1)
-                - next_n
-                + 1
-                + decode_metadata.offsets
-            ).flatten()
-
         if current_platform.is_cuda():
             workspace_manager = current_workspace_manager()
             (topk_workspace,) = workspace_manager.get_simultaneous(
             )
             torch.ops._C.persistent_topk(
                 logits,
-                lengths,
+                decode_metadata.seq_lens,
                 topk_indices,
                 topk_workspace,
                 topk_tokens,
@@ -227,7 +218,7 @@
             ops.top_k_per_row_decode(
                 logits,
                 next_n,
-                decode_metadata.seq_lens,
+                seq_lens,
                 topk_indices,
                 num_rows,
                 logits.stride(0),
@@ -238,7 +229,7 @@
             torch.ops._C.top_k_per_row_decode(
                 logits,
                 next_n,
-                decode_metadata.seq_lens,
+                seq_lens,
                 topk_indices,
                 num_rows,
                 logits.stride(0),
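The deleted `lengths` expansion is not lost: under the stated `(B, next_n)` convention it is exactly the flattened form of the per-token seq_lens that the metadata builder now precomputes, which is why the topk ops can take `seq_lens` directly in either shape. A small sketch of that equivalence (assumes `offsets = arange(next_n)`, as `offsets_buffer` held):

```python
import torch

bs, next_n = 3, 2
seq_lens = torch.tensor([9, 5, 14], dtype=torch.int32)  # per-request lengths (bs,)
offsets = torch.arange(next_n, dtype=torch.int32)       # what offsets_buffer held

# Old sparse_attn_indexer path: expand to one length per logit row, then flatten.
old_lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()

# New metadata path: the builder stores the same values as a (bs, next_n) tensor.
new_seq_lens_2d = seq_lens.unsqueeze(1) - next_n + 1 + offsets

assert torch.equal(old_lengths, new_seq_lens_2d.reshape(-1))
# e.g. seq_lens=[9, 5, 14], next_n=2 -> rows see [8, 9, 4, 5, 13, 14] tokens.
```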
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 5cb4b46a7..402dfc0c7 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -141,11 +141,14 @@ class DeepseekV32IndexerPrefillMetadata:

 @dataclass
 class DeepSeekV32IndexerDecodeMetadata:
     block_table: torch.Tensor
+    # seq_lens: per-token effective context lengths.
+    # - flatten path / plain decode: 1D (batch_size,)
+    # - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1
+    # Both fp8_paged_mqa_logits and the topk kernels accept both shapes.
     seq_lens: torch.Tensor
     decode_lens: torch.Tensor
     requires_padding: bool
     schedule_metadata: torch.Tensor
-    offsets: torch.Tensor | None  # Precomputed offsets for speculative decoding

 @dataclass
@@ -283,21 +286,31 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         sm_count = num_compute_units(self.device.index)
         self.num_sms = sm_count

-        self.decode_lens_buffer = torch.empty(
-            (scheduler_config.max_num_batched_tokens,),
-            dtype=torch.int32,
-            device=self.device,
-        )
         self.offsets_buffer = torch.arange(
             next_n, device=self.device, dtype=torch.int32
         )
-        self.arange_buffer = torch.arange(
-            scheduler_config.max_num_seqs * next_n,
+        self.decode_lens_buffer = torch.zeros(
+            (scheduler_config.max_num_batched_tokens,),
             dtype=torch.int32,
             device=self.device,
         )
-        self.expanded_seq_lens_buffer = torch.zeros(
-            (scheduler_config.max_num_batched_tokens,),
+        if not self.use_flattening and next_n > 1:
+            # Native MTP: 2D buffer for per-token seq_lens.
+            # Flattening path is never used, so no expanded_seq_lens_buffer.
+            self.decode_seq_lens_buffer = torch.zeros(
+                (scheduler_config.max_num_seqs, next_n),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            # Flattening or no MTP: 1D buffer for expanded per-token seq_lens.
+            self.decode_seq_lens_buffer = torch.zeros(
+                (scheduler_config.max_num_batched_tokens,),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        self.arange_buffer = torch.arange(
+            scheduler_config.max_num_seqs * next_n,
             dtype=torch.int32,
             device=self.device,
         )
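Only the shape of `decode_seq_lens_buffer` depends on the builder configuration; `decode_lens_buffer`, `offsets_buffer`, and `arange_buffer` are allocated unconditionally. A compact way to read the branch above, as a hypothetical standalone helper mirroring the `__init__` logic:

```python
def decode_seq_lens_buffer_shape(
    use_flattening: bool,
    next_n: int,
    max_num_seqs: int,
    max_num_batched_tokens: int,
) -> tuple[int, ...]:
    """Shape of the persistent seq_lens buffer allocated in __init__ (sketch)."""
    if not use_flattening and next_n > 1:
        # Native MTP: one precomputed effective length per (request, draft token).
        return (max_num_seqs, next_n)
    # Flattening or no MTP: one entry per (possibly expanded) decode token.
    return (max_num_batched_tokens,)

assert decode_seq_lens_buffer_shape(False, 3, 256, 8192) == (256, 3)
assert decode_seq_lens_buffer_shape(True, 3, 256, 8192) == (8192,)
assert decode_seq_lens_buffer_shape(False, 1, 256, 8192) == (8192,)
```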
@@ -367,6 +380,96 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             skip_kv_gather=skip_kv_gather,
         )

+    def _prepare_decode_tensors(
+        self,
+        seq_lens: torch.Tensor,
+        block_table: torch.Tensor,
+        decode_lens: torch.Tensor,
+        decode_lens_cpu: torch.Tensor,
+        query_start_loc: torch.Tensor,
+        num_decodes: int,
+        num_decode_tokens: int,
+        use_native: bool,
+        next_n: int,
+        max_decode_len: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool]:
+        """Expand seq_lens/block_table/decode_lens for the decode kernels.
+
+        Flatten path (not use_native, max_decode_len > 1):
+            Each multi-token decode request is expanded into individual
+            single-token entries so the kernel always sees next_n=1.
+
+        Native path (use_native or max_decode_len == 1):
+            Plain decode or spec-decode with 2D per-token context lengths.
+
+        Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
+        seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
+        """
+        if not use_native and max_decode_len > 1:
+            assert self.decode_seq_lens_buffer.dim() == 1
+            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
+            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
+            # The context lengths are therefore
+            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
+
+            # 3 + 1 + 4 + 0 = 8
+            actual_expanded = int(decode_lens_cpu.sum().item())
+
+            # Fuse the former expanded_base and expanded_starts into a single
+            # repeat_interleave:
+            #   seq_len_i = (context_len[b] - query_start_loc[b]) + arange[i] + 1
+            #   where context_len[b] = seq_lens[b] - decode_lens[b].
+            # Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
+            #   expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
+            #   result = [8, 9, 10, 7, 9, 10, 11, 12]
+            expanded_offsets = torch.repeat_interleave(
+                seq_lens - decode_lens - query_start_loc,
+                decode_lens,
+                output_size=actual_expanded,
+            )
+
+            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
+            self.decode_seq_lens_buffer[:actual_expanded] = (
+                expanded_offsets + self.arange_buffer[:actual_expanded] + 1
+            )
+            self.decode_seq_lens_buffer[actual_expanded:] = 0
+            seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]
+
+            # Give each of the flattened entries the same block table row as the
+            # original request.
+            self.expanded_block_table_buffer[:actual_expanded] = (
+                torch.repeat_interleave(
+                    block_table, decode_lens, dim=0, output_size=actual_expanded
+                )
+            )
+            if actual_expanded < num_decode_tokens:
+                self.expanded_block_table_buffer[
+                    actual_expanded:num_decode_tokens, 0
+                ] = 0
+            block_table = self.expanded_block_table_buffer[:num_decode_tokens]
+
+            # All reqs now have decode_len=1
+            self.decode_lens_buffer[:num_decode_tokens] = 1
+            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
+            return seq_lens, block_table, decode_lens, num_decode_tokens, False
+        else:
+            # Native path: plain decode (next_n==1) or spec decode
+            # with 2D per-token context lengths (next_n > 1).
+            #
+            # When decode_lens are not truly uniform (e.g. some requests have
+            # decode_len < next_n due to padding or short prefills), the simple
+            # reshape in sparse_attn_indexer won't work. Use pack_seq_triton
+            # (requires_padding) instead.
+            min_decode_len = int(decode_lens_cpu.min().item())
+            requires_padding = min_decode_len != max_decode_len
+            if use_native and next_n > 1:
+                assert self.decode_seq_lens_buffer.dim() == 2
+                # (B, next_n): token j attends to L - next_n + j + 1 KV tokens
+                self.decode_seq_lens_buffer[:num_decodes] = (
+                    seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer
+                )
+                seq_lens = self.decode_seq_lens_buffer[:num_decodes]
+            return seq_lens, block_table, decode_lens, num_decodes, requires_padding
+
     def build(
         self,
         common_prefix_len: int,
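The fused `repeat_interleave` in the flatten path can be checked against the worked example in the comments. The sketch below reproduces it with plain tensors (standalone; buffer writes replaced by ordinary tensors), assuming `query_start_loc` is the cumulative sum of `decode_lens` for the decode requests:

```python
import torch

seq_lens = torch.tensor([10, 7, 12, 0], dtype=torch.int32)      # last request is padding
decode_lens = torch.tensor([3, 1, 4, 0], dtype=torch.int32)
query_start_loc = torch.tensor([0, 3, 4, 8], dtype=torch.int32)  # cumsum of decode_lens

actual_expanded = int(decode_lens.sum())                         # 3 + 1 + 4 + 0 = 8
arange = torch.arange(actual_expanded, dtype=torch.int32)

# Single repeat_interleave replacing the old expanded_base/expanded_starts pair:
expanded_offsets = torch.repeat_interleave(
    seq_lens - decode_lens - query_start_loc, decode_lens, output_size=actual_expanded
)
flattened_seq_lens = expanded_offsets + arange + 1

assert expanded_offsets.tolist() == [7, 7, 7, 3, 4, 4, 4, 4]
assert flattened_seq_lens.tolist() == [8, 9, 10, 7, 9, 10, 11, 12]
```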
@@ -434,68 +537,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         next_n = 1 + self.num_speculative_tokens
         use_native = not self.use_flattening and max_decode_len == next_n

-        if use_native and next_n > 1:
-            offsets = self.offsets_buffer
-        elif max_decode_len > 1:
-            # Flatten multi-token decode requests into single-token
-            # batch entries, expanding seq_lens and block tables so
-            # the kernel always sees next_n=1.
-
-            # Also handles the edge case where use_flattening=False
-            # but max_decode_len != next_n (e.g. a batch containing some
-            # short prefills (q_len < next_n) and no true decodes).
-
-            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
-            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
-            # The context lengths are therefore
-            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
-
-            # 3 + 1 + 4 + 0 = 8
-            actual_expanded = int(decode_lens_cpu.sum().item())
-
-            # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
-            expanded_base = torch.repeat_interleave(
-                seq_lens - decode_lens, decode_lens, output_size=actual_expanded
+        seq_lens, block_table, decode_lens, batch_size, requires_padding = (
+            self._prepare_decode_tensors(
+                seq_lens=seq_lens,
+                block_table=block_table,
+                decode_lens=decode_lens,
+                decode_lens_cpu=decode_lens_cpu,
+                query_start_loc=common_attn_metadata.query_start_loc[:num_decodes],
+                num_decodes=num_decodes,
+                num_decode_tokens=num_decode_tokens,
+                use_native=use_native,
+                next_n=next_n,
+                max_decode_len=max_decode_len,
             )
-
-            # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
-            expanded_starts = torch.repeat_interleave(
-                common_attn_metadata.query_start_loc[:num_decodes],
-                decode_lens,
-                output_size=actual_expanded,
-            )
-
-            # [0, 1, 2, 0, 0, 1, 2, 3]
-            positions_within = (
-                self.arange_buffer[:actual_expanded] - expanded_starts
-            )
-
-            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
-            self.expanded_seq_lens_buffer[:actual_expanded] = (
-                expanded_base + positions_within + 1
-            )
-            self.expanded_seq_lens_buffer[actual_expanded:] = 0
-            seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
-
-            # Give each of the flattened entries the same block table row as the
-            # original request.
-            self.expanded_block_table_buffer[:actual_expanded] = (
-                torch.repeat_interleave(
-                    block_table, decode_lens, dim=0, output_size=actual_expanded
-                )
-            )
-            if actual_expanded < num_decode_tokens:
-                self.expanded_block_table_buffer[
-                    actual_expanded:num_decode_tokens, 0
-                ] = 0
-            block_table = self.expanded_block_table_buffer[:num_decode_tokens]
-
-            # All reqs now have decode_len=1
-            self.decode_lens_buffer[:num_decode_tokens] = 1
-            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
-            offsets = None
-        else:
-            offsets = None
+        )

         # DeepGEMM is required for the paged MQA logits on CUDA devices
         if current_platform.is_cuda() and has_deep_gemm():
@@ -509,9 +564,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             block_table=block_table,
             seq_lens=seq_lens,
             decode_lens=decode_lens,
-            requires_padding=False,
+            requires_padding=requires_padding,
             schedule_metadata=self.scheduler_metadata_buffer,
-            offsets=offsets,
         )

         attn_metadata = DeepseekV32IndexerMetadata(
@@ -531,6 +585,4 @@
             decode=decode_metadata,
         )

-        # if get_tensor_model_parallel_rank() == 0:
-        #     logger.info(f"attn_metadata: {attn_metadata}")
         return attn_metadata
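Putting `build()` and `_prepare_decode_tensors()` together, the path taken and the resulting `seq_lens` layout depend only on `use_flattening`, `next_n`, and the observed decode lengths. A condensed decision sketch (an illustrative pure-Python summary, not code from the PR):

```python
def decode_path(use_flattening: bool, next_n: int, decode_lens: list[int]) -> dict:
    """Summarize which branch build()/_prepare_decode_tensors() ends up in."""
    max_decode_len = max(decode_lens)
    min_decode_len = min(decode_lens)
    use_native = not use_flattening and max_decode_len == next_n

    if not use_native and max_decode_len > 1:
        # Flatten path: every decode token becomes its own single-token entry.
        return {"path": "flatten", "seq_lens_shape": "1D (num_decode_tokens,)",
                "requires_padding": False}
    # Native path: plain decode (next_n == 1) or per-token 2D lengths (next_n > 1).
    return {
        "path": "native",
        "seq_lens_shape": "2D (B, next_n)" if use_native and next_n > 1 else "1D (B,)",
        "requires_padding": min_decode_len != max_decode_len,
    }

print(decode_path(use_flattening=False, next_n=3, decode_lens=[3, 3, 3]))  # native, 2D
print(decode_path(use_flattening=True,  next_n=3, decode_lens=[3, 3, 3]))  # flatten, 1D
print(decode_path(use_flattening=False, next_n=1, decode_lens=[1, 1, 1]))  # native, 1D
```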