[main][BugFix] Fix an accuracy bug in Qwen3-next-MTP during batched inference (#30632)

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2025-12-14 17:32:16 +08:00
committed by GitHub
parent dcb31196da
commit add1b9d3de

View File

@@ -211,7 +211,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
-index = torch.argsort(spec_token_masks)
+index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]