diff --git a/csrc/sampler.cu b/csrc/sampler.cu index 30bfef33c..2e76873c8 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -575,7 +575,7 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( // The range of logits within the row. int rowStart = 0; int seq_len = seqLens[rowIdx / next_n]; - int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; + int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1); // Local pointers to this block if constexpr (!multipleBlocksPerRow && !mergeBlocks) { diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 70281b4a9..3b3be6ac9 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -206,6 +206,8 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 + natively_supported_next_n: list[int] = [1, 2] + # TODO (matt): integrate kernel with next_n = 4 support @classmethod def get_cudagraph_support( @@ -231,7 +233,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): if self.vllm_config.speculative_config else 0 ) + next_n = self.num_speculative_tokens + 1 self.reorder_batch_threshold += self.num_speculative_tokens + self.use_flattening = next_n not in self.natively_supported_next_n sm_count = num_compute_units(self.device.index) self.num_sms = sm_count @@ -241,10 +245,11 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): dtype=torch.int32, device=self.device, ) - - # Pre-allocated buffers for flattening (spec decode). + self.offsets_buffer = torch.arange( + next_n, device=self.device, dtype=torch.int32 + ) self.arange_buffer = torch.arange( - scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens), + scheduler_config.max_num_seqs * next_n, dtype=torch.int32, device=self.device, ) @@ -323,7 +328,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, decode_threshold=self.reorder_batch_threshold + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=not self.use_flattening, ) ) @@ -372,11 +379,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): block_table.clamp_(min=0) max_decode_len = int(decode_lens_cpu.max().item()) - if max_decode_len > 1: + next_n = 1 + self.num_speculative_tokens + use_native = not self.use_flattening and max_decode_len == next_n + + if use_native and next_n > 1: + offsets = self.offsets_buffer + batch_size = num_decodes + elif max_decode_len > 1: # Flatten multi-token decode requests into single-token # batch entries, expanding seq_lens and block tables so # the kernel always sees next_n=1. + # Also handles the edge case where use_flattening=False + # but max_decode_len != next_n (e.g. a batch containing some + # short prefills (q_len < next_n) and no true decodes). + # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is # padding) and decode_lens [3, 1, 4, 0] in the below example comments. # The context lengths are therefore @@ -428,13 +445,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): offsets = None batch_size = num_decode_tokens else: - next_n = 1 + self.num_speculative_tokens - if next_n > 1: - offsets = torch.arange( - next_n, device=self.device, dtype=torch.int32 - ) - else: - offsets = None + offsets = None batch_size = num_decodes # DeepGEMM is required for the paged MQA logits on CUDA devices