[Bugfix] Fix GDN attention crash with mixed decode/spec-decode batches (#34871)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -220,6 +220,16 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
query_lens_cpu.sum().item() - num_prefill_tokens - num_decode_tokens
|
||||
)
|
||||
|
||||
# num_decodes and num_spec_decodes must be mutually exclusive:
|
||||
# Reclassify non-spec decodes as prefills when spec decodes
|
||||
# exist — the prefill kernel handles 1-token sequences with
|
||||
# initial state correctly, producing identical results.
|
||||
if num_decodes > 0 and num_spec_decodes > 0:
|
||||
num_prefills += num_decodes
|
||||
num_prefill_tokens += num_decode_tokens
|
||||
num_decodes = 0
|
||||
num_decode_tokens = 0
|
||||
|
||||
if num_prefills == 0 and num_decodes == 0:
|
||||
spec_token_size = min(
|
||||
num_spec_decodes * (self.num_spec + 1),
|
||||
|
||||
Reference in New Issue
Block a user