[DSV3.2][MTP] Optimize Indexer MTP handling (#36723)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-03-11 00:16:56 -04:00
committed by GitHub
parent fa0d353acf
commit 9040cd40af

View File

@@ -384,12 +384,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
seq_lens - decode_lens, decode_lens, output_size=actual_expanded
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
common_attn_metadata.query_start_loc[:num_decodes],
decode_lens,
output_size=actual_expanded,
)
# [0, 1, 2, 0, 0, 1, 2, 3]
@@ -407,7 +409,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[