[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
num_padded_decodes = attn_metadata.num_padded_decodes
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
state_indices_tensor,
|
||||
num_prefill_tokens,
|
||||
num_prefills,
|
||||
num_padded_decodes,
|
||||
num_decode_tokens,
|
||||
)
|
||||
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
|
||||
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
|
||||
@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
|
||||
state_indices_tensor: torch.Tensor,
|
||||
num_prefill_tokens: int,
|
||||
num_prefills: int,
|
||||
num_padded_decodes: int,
|
||||
num_decode_tokens: int,
|
||||
) -> PrefillDecodeSplit:
|
||||
num_actual_tokens = num_prefill_tokens + num_padded_decodes
|
||||
num_actual_tokens = num_prefill_tokens + num_decode_tokens
|
||||
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
[num_decode_tokens, num_prefill_tokens],
|
||||
dim=-1,
|
||||
)
|
||||
gate_d, gate_p = torch.split(
|
||||
gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
|
||||
gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
|
||||
)
|
||||
|
||||
# num_padded_decodes accounts for CUDA graph padding when applicable
|
||||
# num_decode_tokens accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[: num_padded_decodes + num_prefills],
|
||||
[num_padded_decodes, num_prefills],
|
||||
state_indices_tensor[: num_decode_tokens + num_prefills],
|
||||
[num_decode_tokens, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user