[DSV3.2][MTP] Optimize Indexer MTP handling (#36723)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
fa0d353acf
commit
9040cd40af
@@ -384,12 +384,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
|
||||
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
|
||||
expanded_base = torch.repeat_interleave(
|
||||
seq_lens - decode_lens, decode_lens
|
||||
seq_lens - decode_lens, decode_lens, output_size=actual_expanded
|
||||
)
|
||||
|
||||
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
|
||||
expanded_starts = torch.repeat_interleave(
|
||||
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
|
||||
common_attn_metadata.query_start_loc[:num_decodes],
|
||||
decode_lens,
|
||||
output_size=actual_expanded,
|
||||
)
|
||||
|
||||
# [0, 1, 2, 0, 0, 1, 2, 3]
|
||||
@@ -407,7 +409,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
# Give each of the flattened entries the same block table row as the
|
||||
# original request.
|
||||
self.expanded_block_table_buffer[:actual_expanded] = (
|
||||
torch.repeat_interleave(block_table, decode_lens, dim=0)
|
||||
torch.repeat_interleave(
|
||||
block_table, decode_lens, dim=0, output_size=actual_expanded
|
||||
)
|
||||
)
|
||||
if actual_expanded < num_decode_tokens:
|
||||
self.expanded_block_table_buffer[
|
||||
|
||||
Reference in New Issue
Block a user