From f40d9879f2dfe4d878b77768ad30935ea4e42b1f Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Mon, 6 Apr 2026 16:39:37 +0100 Subject: [PATCH] [Models][GDN] Remove GPU/CPU syncs in `GDNAttentionMetadata.build` during speculative decoding (#38047) Signed-off-by: Lukas Geiger --- vllm/v1/attention/backends/gdn_attn.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 5ebf040be..85715e91a 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -253,7 +253,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ) # Filter by spec_sequence_masks to exclude padded sequences spec_state_indices_tensor = block_table_tensor[ - spec_sequence_masks, : self.num_spec + 1 + spec_sequence_masks_cpu, : self.num_spec + 1 ] non_spec_state_indices_tensor = None # Padded sequences are always at the back, so the first @@ -264,7 +264,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] non_spec_query_start_loc_cpu = None else: spec_token_masks = torch.repeat_interleave( - spec_sequence_masks, query_lens + spec_sequence_masks, + query_lens, + output_size=query_start_loc_cpu[-1].item(), ) index = torch.argsort(spec_token_masks, stable=True) num_non_spec_tokens = num_prefill_tokens + num_decode_tokens @@ -272,10 +274,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] spec_token_indx = index[num_non_spec_tokens:] spec_state_indices_tensor = block_table_tensor[ - spec_sequence_masks, : self.num_spec + 1 + spec_sequence_masks_cpu, : self.num_spec + 1 ] non_spec_state_indices_tensor = block_table_tensor[ - ~spec_sequence_masks, 0 + ~spec_sequence_masks_cpu, 0 ] spec_query_start_loc = torch.zeros( @@ -284,7 +286,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] device=query_start_loc.device, ) torch.cumsum( - query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] + query_lens[spec_sequence_masks_cpu], + dim=0, + out=spec_query_start_loc[1:], ) non_spec_query_start_loc = torch.zeros( query_lens.size(0) - num_spec_decodes + 1, @@ -292,7 +296,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] device=query_start_loc.device, ) torch.cumsum( - query_lens[~spec_sequence_masks], + query_lens[~spec_sequence_masks_cpu], dim=0, out=non_spec_query_start_loc[1:], ) @@ -307,7 +311,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ) assert num_accepted_tokens is not None - num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks_cpu] chunk_indices: torch.Tensor | None = None chunk_offsets: torch.Tensor | None = None @@ -331,8 +335,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] if num_prefills > 0: has_initial_state = context_lens_tensor > 0 - if spec_sequence_masks is not None: - has_initial_state = has_initial_state[~spec_sequence_masks] + if spec_sequence_masks_cpu is not None: + has_initial_state = has_initial_state[~spec_sequence_masks_cpu] assert non_spec_query_start_loc_cpu is not None nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(