[Bugfix] Disable CG for Whisper+FA2 (#33164)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -257,6 +257,26 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
)
|
||||
# NOTE(review): class-level capability flag; presumably advertises that this
# builder can update an existing block table in place — confirm against the
# AttentionMetadataBuilder base class, which is not visible here.
supports_update_block_table: bool = True
|
||||
|
||||
@classmethod
def get_cudagraph_support(
    cls,
    vllm_config: "VllmConfig",
    kv_cache_spec: "AttentionSpec",
) -> AttentionCGSupport:
    """Report the level of CUDA-graph support for this configuration.

    FlashAttention 2 is known to produce incorrect results under CUDA
    graphs with encoder-decoder models (vllm-project/vllm#33091), so
    capture is disabled outright for that combination. In every other
    case the class-level default support level is returned unchanged.
    """
    # Short-circuit keeps get_flash_attn_version() from being called
    # unless the model is actually encoder-decoder, same as the original.
    model_is_enc_dec = vllm_config.model_config.is_encoder_decoder
    if model_is_enc_dec and get_flash_attn_version() == 2:
        logger.warning_once(
            "FlashAttention2 does not support CUDA graphs with "
            "encoder-decoder models due to accuracy issues reported in #33091. "
            "Disabling CUDA graph."
        )
        return AttentionCGSupport.NEVER
    return cls._cudagraph_support
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
|
||||
Reference in New Issue
Block a user