[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2025-12-02 13:48:08 -05:00
committed by GitHub
parent 2d613de9ae
commit 1d93f11675
5 changed files with 20 additions and 46 deletions

View File

@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp):
 conv_state = self_kv_cache[0].transpose(-1, -2)
 ssm_state = self_kv_cache[1]
 has_initial_states_p = attn_metadata.has_initial_states_p
-num_padded_decodes = attn_metadata.num_padded_decodes
 # 1. Gated MLP's linear projection
 projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp):
 state_indices_tensor,
 num_prefill_tokens,
 num_prefills,
-num_padded_decodes,
+num_decode_tokens,
 )
 hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
 hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
 state_indices_tensor: torch.Tensor,
 num_prefill_tokens: int,
 num_prefills: int,
-num_padded_decodes: int,
+num_decode_tokens: int,
 ) -> PrefillDecodeSplit:
-num_actual_tokens = num_prefill_tokens + num_padded_decodes
+num_actual_tokens = num_prefill_tokens + num_decode_tokens
 # In v1, decode tokens come first, then prefill tokens.
 hidden_states_BC_d, hidden_states_BC_p = torch.split(
 hidden_states_BC[..., :num_actual_tokens],
-[num_padded_decodes, num_prefill_tokens],
+[num_decode_tokens, num_prefill_tokens],
 dim=-1,
 )
 gate_d, gate_p = torch.split(
-gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
+gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
 )
-# num_padded_decodes accounts for CUDA graph padding when applicable
+# num_decode_tokens accounts for CUDA graph padding when applicable
 state_indices_tensor_d, state_indices_tensor_p = torch.split(
-state_indices_tensor[: num_padded_decodes + num_prefills],
-[num_padded_decodes, num_prefills],
+state_indices_tensor[: num_decode_tokens + num_prefills],
+[num_decode_tokens, num_prefills],
 dim=0,
 )