[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -758,11 +758,12 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""ROCM AITER FA supports decoder and encoder-decoder (cross) attention."""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
"""ENCODER_DECODER is not supported because the prefill path uses
|
||||
flash_attn_varlen_func with cu_seqlens_k set to decoder
|
||||
query_start_loc (not encoder seq lens) and causal=True, both of
|
||||
which are incorrect for cross-attention layers.
|
||||
"""
|
||||
return attn_type in (AttentionType.DECODER,)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
|
||||
@@ -212,12 +212,17 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""RocmAttention supports all attention types."""
|
||||
"""ENCODER_DECODER is not supported because
|
||||
chunked_prefill_paged_decode's prefill kernel (context_attention_fwd)
|
||||
assumes self-attention semantics: it treats passed K/V as new tokens
|
||||
to mix with cached K/V. For cross-attention layers the encoder K/V
|
||||
are already fully cached, so mixing them again produces incorrect
|
||||
results when max_query_len > 1 (e.g. beam search).
|
||||
"""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user