[BugFix] Fix Whisper FA2 + full cudagraphs (#33360)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -263,18 +263,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         vllm_config: "VllmConfig",
         kv_cache_spec: "AttentionSpec",
     ) -> AttentionCGSupport:
-        # FA2 does not support CUDA graphs with encoder-decoder models due to
-        # accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
-        if (
-            vllm_config.model_config.is_encoder_decoder
-            and get_flash_attn_version() == 2
-        ):
-            logger.warning_once(
-                "FlashAttention2 does not support CUDA graphs with "
-                "encoder-decoder models due to accuracy issues reported in #33091. "
-                "Disabling CUDA graph."
-            )
-            return AttentionCGSupport.NEVER
         return cls._cudagraph_support
 
     def __init__(
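The hunk above deletes the guard that forced AttentionCGSupport.NEVER whenever FA2 met an encoder-decoder model, so the builder once again falls through to cls._cudagraph_support. For readers unfamiliar with this gating, here is a minimal, self-contained sketch of the pattern; the enum is modeled on vLLM's AttentionCGSupport, but pick_execution_mode and its callers are hypothetical, not vLLM code:

import enum

class AttentionCGSupport(enum.Enum):
    # Modeled on vLLM's enum; only the two values used here are shown.
    NEVER = 0    # backend opts out of CUDA graph capture entirely
    ALWAYS = 1   # backend is safe to capture for any batch shape

def pick_execution_mode(support: AttentionCGSupport) -> str:
    # Hypothetical helper: a NEVER verdict downgrades to eager execution.
    if support is AttentionCGSupport.NEVER:
        return "eager"
    return "full_cudagraph"

# Before this commit, Whisper (encoder-decoder) on FA2 always hit NEVER:
print(pick_execution_mode(AttentionCGSupport.NEVER))   # eager
print(pick_execution_mode(AttentionCGSupport.ALWAYS))  # full_cudagraph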
@@ -1395,12 +1395,14 @@ class GPUModelRunner(
         num_scheduled_tokens: dict[str, int],
         kv_cache_spec: KVCacheSpec,
         num_reqs: int,
+        for_cudagraph_capture: bool = False,
     ) -> tuple[torch.Tensor | None, np.ndarray | None]:
         if not isinstance(kv_cache_spec, CrossAttentionSpec):
             return None, None
 
         # Zero out buffer for padding requests that are not actually scheduled (CGs)
         self.encoder_seq_lens.np[:num_reqs] = 0
+
         # Build encoder_seq_lens array mapping request indices to
         # encoder lengths for inputs scheduled in this batch
         for req_id in num_scheduled_tokens:
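The zeroing step above matters because CUDA graph execution pads the batch with dummy request slots; without the reset, stale encoder lengths from an earlier step would leak into the padded slots. A standalone NumPy sketch of the idea, with buffer and helper names invented for illustration:

import numpy as np

MAX_NUM_REQS = 8
encoder_seq_lens = np.zeros(MAX_NUM_REQS, dtype=np.int32)  # persistent buffer

def fill_encoder_seq_lens(lens_by_index: dict[int, int], num_reqs_padded: int) -> np.ndarray:
    # Clear every slot in the padded range first, so padding requests
    # that are not actually scheduled contribute a length of zero.
    encoder_seq_lens[:num_reqs_padded] = 0
    for req_index, enc_len in lens_by_index.items():
        encoder_seq_lens[req_index] = enc_len
    return encoder_seq_lens[:num_reqs_padded]

# Two live requests inside a batch padded to four slots:
print(fill_encoder_seq_lens({0: 1500, 1: 750}, 4))  # [1500  750    0    0]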
@@ -1417,6 +1419,15 @@ class GPUModelRunner(
                 feature.mm_position.length for feature in req_state.mm_features
             )
             self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+        if for_cudagraph_capture:
+            # During CUDA graph capture, we need to use realistic encoder lengths
+            # so that max_seqlen_k is captured with the correct value.
+            max_encoder_len = getattr(
+                self.model_config.hf_config,
+                "max_source_positions",
+                self.max_encoder_len,
+            )
+            self.encoder_seq_lens.np[:num_reqs] = max_encoder_len
 
         self.encoder_seq_lens.copy_to_gpu(num_reqs)
         encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
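During capture, the buffer is instead filled with the model's maximum encoder length so that FA2's max_seqlen_k gets baked into the graph at its worst-case value; getattr with a default covers configs that lack max_source_positions (Whisper's HF config defines it, defaulting to 1500). A small sketch of just the fallback pattern, with a stand-in config object rather than the real HF class:

from types import SimpleNamespace

FALLBACK_MAX_ENCODER_LEN = 2048  # stand-in for self.max_encoder_len

whisper_like = SimpleNamespace(max_source_positions=1500)
config_without_attr = SimpleNamespace()

for hf_config in (whisper_like, config_without_attr):
    # Same three-argument getattr as in the hunk above: attribute if
    # present, otherwise the runner's precomputed maximum.
    max_encoder_len = getattr(hf_config, "max_source_positions", FALLBACK_MAX_ENCODER_LEN)
    print(max_encoder_len)  # 1500, then 2048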
@@ -1834,6 +1845,7 @@ class GPUModelRunner(
                 num_scheduled_tokens or {},
                 kv_cache_group.kv_cache_spec,
                 num_reqs_padded,
+                for_cudagraph_capture=for_cudagraph_capture,
             )
             if kv_cache_gid > 0:
                 cm.block_table_tensor = _get_block_table(kv_cache_gid)
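The last hunk simply threads the new flag through the metadata-building call, so only the capture path sees the inflated lengths while normal replays keep the real per-request values. A toy illustration of that split (all names invented):

def build_seq_lens(num_reqs: int, actual_len: int, for_cudagraph_capture: bool = False) -> list[int]:
    # Capture records the worst case; replay uses what was scheduled.
    worst_case_len = 1500
    length = worst_case_len if for_cudagraph_capture else actual_len
    return [length] * num_reqs

print(build_seq_lens(4, 37, for_cudagraph_capture=True))  # [1500, 1500, 1500, 1500]
print(build_seq_lens(4, 37))                              # [37, 37, 37, 37]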