[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -758,11 +758,12 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Return whether this backend can serve the given attention type.

    ROCm AITER FlashAttention supports decoder self-attention and
    encoder-decoder (cross) attention. Other types (e.g. pure encoder
    attention) are rejected so the dispatcher can fall back to another
    backend.

    Args:
        attn_type: One of the ``AttentionType`` string constants.

    Returns:
        True if ``attn_type`` is DECODER or ENCODER_DECODER.
    """
    # NOTE: cross-attention (ENCODER_DECODER) is supported here; an
    # earlier revision rejected it because the prefill path passed
    # decoder query_start_loc as cu_seqlens_k and used causal=True,
    # both incorrect for cross-attention layers.
    return attn_type in (
        AttentionType.DECODER,
        AttentionType.ENCODER_DECODER,
    )
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
|
||||
Reference in New Issue
Block a user