diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 41109ff41..c7a41abe5 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -208,7 +208,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] non_spec_query_lens = query_lens[~spec_sequence_masks] num_decodes = (non_spec_query_lens == 1).sum().item() - num_prefills = non_spec_query_lens.size(0) - num_decodes + # Exclude zero-length padded sequences from prefill count. + num_zero_len = (non_spec_query_lens == 0).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes - num_zero_len num_decode_tokens = num_decodes num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens num_spec_decode_tokens = ( @@ -228,9 +230,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] non_spec_token_indx = torch.empty( 0, dtype=torch.int32, device=query_start_loc.device ) - spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1] + # Filter by spec_sequence_masks to exclude padded sequences + spec_state_indices_tensor = block_table_tensor[ + spec_sequence_masks, : self.num_spec + 1 + ] non_spec_state_indices_tensor = None - spec_query_start_loc = query_start_loc + # Padded sequences are always at the back, so the first + # num_spec_decodes + 1 entries of query_start_loc already + # contain the correct cumulative token counts. + spec_query_start_loc = query_start_loc[: num_spec_decodes + 1] non_spec_query_start_loc = None non_spec_query_start_loc_cpu = None else: @@ -294,6 +302,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] else: has_initial_state = None + # This code path assumes that either non-spec decodes or spec decodes + # are present, but not both. 
+ assert not (num_decodes > 0 and num_spec_decodes > 0), ( + f"num_decodes: {num_decodes}, num_spec_decodes: {num_spec_decodes}" + ) + # Prepare tensors for cudagraph # Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph batch_size = m.num_actual_tokens @@ -312,7 +326,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) self.spec_sequence_masks[:num_spec_decodes].copy_( - spec_sequence_masks, non_blocking=True + spec_sequence_masks[:num_spec_decodes], non_blocking=True ) spec_sequence_masks = self.spec_sequence_masks[:batch_size] spec_sequence_masks[num_spec_decodes:].fill_(False)