[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni authored on 2025-12-02 13:48:08 -05:00, committed by GitHub
parent 2d613de9ae
commit 1d93f11675
5 changed files with 20 additions and 46 deletions
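
For context: CUDA graph padding rounds a token count up to the nearest captured graph size so a pre-captured graph can be replayed. A minimal sketch of that rounding, assuming a sorted capture-size list and a free function (in vLLM this lives on `VllmConfig.pad_for_cudagraph`, backed by the compilation config's capture sizes):

# Minimal sketch, not vLLM's implementation: round num_tokens up to the
# nearest captured CUDA graph size so a pre-captured graph can replay it.
from bisect import bisect_left

CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]  # assumed, sorted ascending

def pad_for_cudagraph(num_tokens: int) -> int:
    i = bisect_left(CAPTURE_SIZES, num_tokens)
    if i == len(CAPTURE_SIZES):
        return num_tokens  # beyond the largest graph: run eagerly, no padding
    return CAPTURE_SIZES[i]

assert pad_for_cudagraph(3) == 4
assert pad_for_cudagraph(17) == 32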


@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             )
         else:
             has_initial_state = None
-        num_actual_tokens = (
-            num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
-        )
 
-        # prepare tensors for cudagraph
-        #
-        # With speculative decoding, the xgrammar backend may rollback tokens
-        # and causing some sequences has less draft tokens than self.num_spec.
-        #
-        # In above cases, the max possible batch size for n tokens, can be
-        # min(n, cudagraph_max_bs).
+        # Prepare tensors for cudagraph
+        # Note: m.num_actual_tokens is already padded by the model runner
+        # for CUDAGraph
+        batch_size = m.num_actual_tokens
         if (
             self.use_full_cuda_graph
             and num_prefills == 0
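
The replacement comment states the invariant this commit relies on: the model runner pads `num_actual_tokens` once, before any attention backend builds its metadata, so builders consume the value as-is instead of each re-padding. A rough sketch of that flow; `CommonMeta` and `run_step` are hypothetical stand-ins for the runner-side objects:

# Rough sketch of the assumed flow; names here are hypothetical.
from dataclasses import dataclass

@dataclass
class CommonMeta:
    num_actual_tokens: int  # padded exactly once, by the runner

def run_step(pad_for_cudagraph, num_scheduled_tokens, builders):
    # The model runner pads centrally...
    padded = pad_for_cudagraph(num_scheduled_tokens)
    m = CommonMeta(num_actual_tokens=padded)
    # ...and every backend builder reads m.num_actual_tokens unchanged,
    # instead of calling pad_for_cudagraph again per backend.
    return [build(m) for build in builders]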
@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             and num_spec_decodes <= self.decode_cudagraph_max_bs
             and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
         ):
-            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
-            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
-
             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True
             )
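
This branch (and the non-spec branch in the next hunk) keeps the persistent-buffer staging: real state indices are copied into a preallocated tensor sized for `decode_cudagraph_max_bs`, and the captured graph replays over a fixed-size view of the padded batch. A hedged torch sketch of that pattern, with illustrative shapes and values:

import torch

# Illustrative values; the 1-D buffer shape and -1 fill are assumptions.
decode_cudagraph_max_bs = 8
buf = torch.full((decode_cudagraph_max_bs,), -1, dtype=torch.int32)

num_spec_decodes = 3
spec_state_indices = torch.tensor([5, 9, 2], dtype=torch.int32)
# Stage only the real entries, as the copy_ in the diff does.
buf[:num_spec_decodes].copy_(spec_state_indices)
# The CUDA graph replays over a fixed-size slice of the padded batch.
padded_bs = 4  # m.num_actual_tokens, already padded by the runner
graph_view = buf[:padded_bs]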
@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             and num_spec_decodes == 0
             and num_decodes <= self.decode_cudagraph_max_bs
         ):
-            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
-            batch_size = num_actual_tokens
-
             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True
             )
@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=m.num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
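
With the builder-local `num_actual_tokens` gone, the metadata carries the runner-padded count directly. An abridged sketch of the dataclass, limited to the fields visible in this hunk (the real `GDNAttentionMetadata` has more):

from dataclasses import dataclass

@dataclass
class GDNAttentionMetadata:  # abridged to the fields in this hunk
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int  # already padded by the model runner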