[ROCm] Enabling encoder and encoder-decoder on ROCm and AITER unified backends (#35334)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Author: Gregory Shtrasberg
Date: 2026-02-27 15:32:55 -06:00
Committed by: GitHub
Parent: 2ce6f3cf67
Commit: 9fa6c68fa6
3 changed files with 106 additions and 7 deletions

@@ -171,8 +171,8 @@ Priority is **1 = highest** (tried first).
 | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
 | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
 | `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | All | N/A |
+| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
 | `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
 | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

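With the two rows above flipped from "Decoder" to "All", encoder-only and encoder-decoder models can be routed through these backends on ROCm. A minimal usage sketch, assuming the standard `VLLM_ATTENTION_BACKEND` override; the embedding model name is illustrative only, not taken from this PR:

```python
# Sketch, not part of this PR: force the AITER unified backend, then run an
# encoder-only (embedding) model. Requires a ROCm build of vLLM with AITER.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_UNIFIED_ATTN"

from vllm import LLM  # import after setting the env var, before engine creation

llm = LLM(model="BAAI/bge-base-en-v1.5")  # illustrative encoder-only model
print(llm.embed(["ROCm encoder attention test"])[0].outputs.embedding[:4])
```
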
@@ -55,6 +55,16 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
         return RocmAttentionMetadataBuilder
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """RocmAiterUnifiedAttention supports all attention types."""
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER,
+            AttentionType.ENCODER_ONLY,
+            AttentionType.ENCODER_DECODER,
+        )
+
 
 class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
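The new `supports_attn_type()` classmethod is what higher-level backend selection can consult before wiring a layer to this backend. An illustration only, not the actual vLLM selector code, and with assumed module paths:

```python
# Filter candidate ROCm backends by the attention type a layer needs.
# Import paths below are assumptions about the module layout; adjust as needed.
from vllm.attention import AttentionType
from vllm.v1.attention.backends.rocm_aiter_unified_attn import (
    RocmAiterUnifiedAttentionBackend,
)
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend

candidates = [RocmAttentionBackend, RocmAiterUnifiedAttentionBackend]
usable = [b for b in candidates if b.supports_attn_type(AttentionType.ENCODER_DECODER)]
print([b.__name__ for b in usable])  # both backends qualify after this change
```
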
@@ -143,6 +153,19 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
 
+        # Handle encoder attention differently - no KV cache needed
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(
+                query[:num_actual_tokens],
+                key[:num_actual_tokens],
+                value[:num_actual_tokens],
+                output[:num_actual_tokens],
+                attn_metadata,
+                layer,
+            )
+
         key_cache, value_cache = kv_cache.unbind(0)
 
         if self.kv_cache_dtype.startswith("fp8"):
@@ -195,6 +218,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         kv_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
     ):
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return
         key_cache, value_cache = kv_cache.unbind(0)
 
         # Reshape the input keys and values and store them in the cache.
@@ -224,6 +251,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         kv_cache: torch.Tensor,
         layer_slot_mapping: torch.Tensor,
     ):
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return
         key_cache, value_cache = kv_cache.unbind(0)
 
         flash_layout = True

@@ -205,6 +205,16 @@ class RocmAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["RocmAttentionImpl"]:
         return RocmAttentionImpl
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """RocmAttention supports all attention types."""
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER,
+            AttentionType.ENCODER_ONLY,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -244,6 +254,7 @@ class RocmAttentionImpl(AttentionImpl):
         kv_sharing_target_layer_name: int | None = None,
         sinks: torch.Tensor | None = None,
     ) -> None:
+        self.attn_type = attn_type
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -266,11 +277,6 @@ class RocmAttentionImpl(AttentionImpl):
 
         RocmAttentionBackend.validate_head_size(head_size)
 
-        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
-            raise NotImplementedError(
-                "Encoder self-attention is not implemented for RocmAttentionImpl"
-            )
-
         self.fp8_dtype = current_platform.fp8_dtype()
         self.sinks = sinks
@@ -281,6 +287,54 @@ class RocmAttentionImpl(AttentionImpl):
                 f"num_heads: {num_heads}."
             )
 
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # FP8 KV-cache quantization is not supported on the encoder path
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "quantization is not supported for encoder attention"
+            )
+
+        # Use encoder-specific metadata for sequence information
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        max_query_len = attn_metadata.max_query_len
+
+        # Call flash attention directly on Q, K, V tensors
+        from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
+
+        context_attention_fwd(
+            q=query,
+            k=key,
+            v=value,
+            o=output,
+            b_start_loc=query_start_loc,
+            b_seq_len=seq_lens,
+            max_input_len=max_query_len,
+            is_causal=False,
+            softmax_scale=self.scale,
+            sliding_window_q=self.sliding_window[0],
+            sliding_window_k=self.sliding_window[1],
+        )
+        return output
+
     def forward(
         self,
         layer: torch.nn.Module,
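The encoder path hands the packed (unpadded) Q/K/V tensors straight to the Triton prefill kernel with `is_causal=False`; `query_start_loc` and `seq_lens` describe where each sequence's tokens sit in the packed token dimension. A small sketch of that layout with invented values:

```python
# Packed varlen layout consumed by the encoder path (illustrative numbers only).
import torch

seq_lens = torch.tensor([3, 5, 2])                    # three encoder sequences
query_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(seq_lens, dim=0)   # -> [0, 3, 8, 10]

num_tokens, num_heads, head_size = int(seq_lens.sum()), 8, 64
q = torch.randn(num_tokens, num_heads, head_size)     # no padding between sequences
# Tokens of sequence i are q[query_start_loc[i]:query_start_loc[i + 1]].
seq1 = q[query_start_loc[1]:query_start_loc[2]]       # shape: [5, 8, 64]
```
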
@@ -330,6 +384,16 @@ class RocmAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
 
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            return self._forward_encoder_attention(
+                query[:num_actual_tokens],
+                key[:num_actual_tokens],
+                value[:num_actual_tokens],
+                output[:num_actual_tokens],
+                attn_metadata,
+                layer,
+            )
+
         key_cache, value_cache = PagedAttention.split_kv_cache(
             kv_cache, self.num_kv_heads, self.head_size
         )
@@ -380,6 +444,8 @@ class RocmAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
     ):
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            return
         key_cache, value_cache = PagedAttention.split_kv_cache(
             kv_cache, self.num_kv_heads, self.head_size
         )
@@ -432,6 +498,8 @@ class RocmAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         layer_slot_mapping: torch.Tensor,
     ):
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            return
         key_cache, value_cache = PagedAttention.split_kv_cache(
             kv_cache,
             layer.num_kv_heads,  # type: ignore[attr-defined]