[Models][GDN] Remove GPU/CPU syncs in GDNAttentionMetadata.build during speculative decoding (#38047)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger
2026-04-06 16:39:37 +01:00
committed by GitHub
parent 47e605092b
commit f40d9879f2

View File

@@ -253,7 +253,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
)
# Filter by spec_sequence_masks to exclude padded sequences
spec_state_indices_tensor = block_table_tensor[
-spec_sequence_masks, : self.num_spec + 1
+spec_sequence_masks_cpu, : self.num_spec + 1
]
non_spec_state_indices_tensor = None
# Padded sequences are always at the back, so the first
@@ -264,7 +264,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_query_start_loc_cpu = None
else:
spec_token_masks = torch.repeat_interleave(
-spec_sequence_masks, query_lens
+spec_sequence_masks,
+query_lens,
+output_size=query_start_loc_cpu[-1].item(),
)
index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
@@ -272,10 +274,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = block_table_tensor[
-spec_sequence_masks, : self.num_spec + 1
+spec_sequence_masks_cpu, : self.num_spec + 1
]
non_spec_state_indices_tensor = block_table_tensor[
-~spec_sequence_masks, 0
+~spec_sequence_masks_cpu, 0
]
spec_query_start_loc = torch.zeros(
@@ -284,7 +286,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
device=query_start_loc.device,
)
torch.cumsum(
-query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
+query_lens[spec_sequence_masks_cpu],
+dim=0,
+out=spec_query_start_loc[1:],
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
@@ -292,7 +296,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
device=query_start_loc.device,
)
torch.cumsum(
-query_lens[~spec_sequence_masks],
+query_lens[~spec_sequence_masks_cpu],
dim=0,
out=non_spec_query_start_loc[1:],
)
@@ -307,7 +311,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
)
assert num_accepted_tokens is not None
-num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
+num_accepted_tokens = num_accepted_tokens[spec_sequence_masks_cpu]
chunk_indices: torch.Tensor | None = None
chunk_offsets: torch.Tensor | None = None
@@ -331,8 +335,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
-if spec_sequence_masks is not None:
-has_initial_state = has_initial_state[~spec_sequence_masks]
+if spec_sequence_masks_cpu is not None:
+has_initial_state = has_initial_state[~spec_sequence_masks_cpu]
assert non_spec_query_start_loc_cpu is not None
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(