diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index d94055cbe..f8ff2fc2e 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -384,12 +384,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):

         # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
         expanded_base = torch.repeat_interleave(
-            seq_lens - decode_lens, decode_lens
+            seq_lens - decode_lens, decode_lens, output_size=actual_expanded
         )

         # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
         expanded_starts = torch.repeat_interleave(
-            common_attn_metadata.query_start_loc[:num_decodes], decode_lens
+            common_attn_metadata.query_start_loc[:num_decodes],
+            decode_lens,
+            output_size=actual_expanded,
         )

         # [0, 1, 2, 0, 0, 1, 2, 3]
@@ -407,7 +409,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         # Give each of the flattened entries the same block table row as the
         # original request.
         self.expanded_block_table_buffer[:actual_expanded] = (
-            torch.repeat_interleave(block_table, decode_lens, dim=0)
+            torch.repeat_interleave(
+                block_table, decode_lens, dim=0, output_size=actual_expanded
+            )
         )
         if actual_expanded < num_decode_tokens:
             self.expanded_block_table_buffer[