[Models][GDN] Remove GPU/CPU syncs in GDNAttentionMetadata.build during speculative decoding (#38047)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -253,7 +253,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
)
|
||||
# Filter by spec_sequence_masks to exclude padded sequences
|
||||
spec_state_indices_tensor = block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
spec_sequence_masks_cpu, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = None
|
||||
# Padded sequences are always at the back, so the first
|
||||
@@ -264,7 +264,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_query_start_loc_cpu = None
|
||||
else:
|
||||
spec_token_masks = torch.repeat_interleave(
|
||||
spec_sequence_masks, query_lens
|
||||
spec_sequence_masks,
|
||||
query_lens,
|
||||
output_size=query_start_loc_cpu[-1].item(),
|
||||
)
|
||||
index = torch.argsort(spec_token_masks, stable=True)
|
||||
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
|
||||
@@ -272,10 +274,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_token_indx = index[num_non_spec_tokens:]
|
||||
|
||||
spec_state_indices_tensor = block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
spec_sequence_masks_cpu, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = block_table_tensor[
|
||||
~spec_sequence_masks, 0
|
||||
~spec_sequence_masks_cpu, 0
|
||||
]
|
||||
|
||||
spec_query_start_loc = torch.zeros(
|
||||
@@ -284,7 +286,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
|
||||
query_lens[spec_sequence_masks_cpu],
|
||||
dim=0,
|
||||
out=spec_query_start_loc[1:],
|
||||
)
|
||||
non_spec_query_start_loc = torch.zeros(
|
||||
query_lens.size(0) - num_spec_decodes + 1,
|
||||
@@ -292,7 +296,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[~spec_sequence_masks],
|
||||
query_lens[~spec_sequence_masks_cpu],
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:],
|
||||
)
|
||||
@@ -307,7 +311,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
)
|
||||
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks_cpu]
|
||||
|
||||
chunk_indices: torch.Tensor | None = None
|
||||
chunk_offsets: torch.Tensor | None = None
|
||||
@@ -331,8 +335,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
|
||||
if num_prefills > 0:
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
if spec_sequence_masks_cpu is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks_cpu]
|
||||
assert non_spec_query_start_loc_cpu is not None
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(
|
||||
|
||||
Reference in New Issue
Block a user