[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
)
|
||||
else:
|
||||
has_initial_state = None
|
||||
num_actual_tokens = (
|
||||
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
|
||||
)
|
||||
|
||||
# prepare tensors for cudagraph
|
||||
#
|
||||
# With speculative decoding, the xgrammar backend may roll back tokens
|
||||
# and cause some sequences to have fewer draft tokens than self.num_spec.
|
||||
#
|
||||
# In the above cases, the max possible batch size for n tokens can be
|
||||
# min(n, cudagraph_max_bs).
|
||||
# Prepare tensors for cudagraph
|
||||
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
|
||||
batch_size = m.num_actual_tokens
|
||||
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
and num_spec_decodes <= self.decode_cudagraph_max_bs
|
||||
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
|
||||
|
||||
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
|
||||
spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
and num_spec_decodes == 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = num_actual_tokens
|
||||
|
||||
self.non_spec_state_indices_tensor[:num_decodes].copy_(
|
||||
non_spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_spec_decodes=num_spec_decodes,
|
||||
num_spec_decode_tokens=num_spec_decode_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_actual_tokens=m.num_actual_tokens,
|
||||
has_initial_state=has_initial_state,
|
||||
spec_query_start_loc=spec_query_start_loc,
|
||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||
|
||||
Reference in New Issue
Block a user