diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index fefbf6a41..d59b74782 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -2070,5 +2070,56 @@ class rocm_aiter_ops:
             out_=out_,
         )
+    @staticmethod
+    def paged_attention_common(
+        Q: torch.Tensor,
+        K: torch.Tensor,
+        V: torch.Tensor,
+        tmp_out: torch.Tensor,
+        max_logits: torch.Tensor,
+        exp_sums: torch.Tensor,
+        max_seq_len: int,
+        block_tables: torch.Tensor,
+        context_lens: torch.Tensor,
+        block_tables_stride0: int,
+        scale: float,
+        K_QScale_hip: torch.Tensor,
+        V_QScale_hip: torch.Tensor,
+        K_QScale_asm: torch.Tensor,
+        V_QScale_asm: torch.Tensor,
+        out_: torch.Tensor,
+        kv_cache_dtype: str,
+    ):
+        """
+        Paged attention common function.
+
+        This function is NOT wrapped with @is_aiter_supported decorator
+        to allow explicit backend selection via attention_config to work
+        even when VLLM_ROCM_USE_AITER=0.
+
+        Note: This performs lazy import of aiter.paged_attention_common
+        """
+        from aiter import paged_attention_common
+
+        return paged_attention_common(
+            Q=Q,
+            K=K,
+            V=V,
+            tmp_out=tmp_out,
+            max_logits=max_logits,
+            exp_sums=exp_sums,
+            max_seq_len=max_seq_len,
+            block_tables=block_tables,
+            context_lens=context_lens,
+            block_tables_stride0=block_tables_stride0,
+            scale=scale,
+            K_QScale_hip=K_QScale_hip,
+            V_QScale_hip=V_QScale_hip,
+            K_QScale_asm=K_QScale_asm,
+            V_QScale_asm=V_QScale_asm,
+            out_=out_,
+            kv_cache_dtype=kv_cache_dtype,
+        )
+
 
 
 rocm_aiter_ops.register_ops_once()
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 6c6e82b1b..d0aebf614 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -1247,7 +1247,23 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 v_descale=layer._v_scale.expand(descale_shape),
             )
         elif rocm_aiter_ops.is_shuffle_kv_cache_enabled():
-            num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
+            _, num_heads, head_size = query.shape
+            num_seqs = attn_metadata.seq_lens.shape[0]
+            max_num_partitions = (
+                attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
+            ) // _PARTITION_SIZE_ROCM
+            tmp_out = torch.empty(
+                (num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            exp_sums = torch.empty(
+                (num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=query.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            num_blocks, block_size, num_kv_heads, _ = key_cache.shape
             x = 16 // key_cache.element_size()
             k_cache_template = torch.empty(
                 [num_blocks, num_kv_heads, head_size // x, block_size, x],
@@ -1261,18 +1277,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
             )
             new_key_cache = key_cache.view_as(k_cache_template)
             new_value_cache = value_cache.view_as(v_cache_template)
-            rocm_aiter_ops.pa_fwd_asm(
+            k_qscale = (
+                layer._k_scale
+                if attn_metadata.k_scale is None
+                else attn_metadata.k_scale
+            )
+            v_qscale = (
+                layer._v_scale
+                if attn_metadata.v_scale is None
+                else attn_metadata.v_scale
+            )
+            rocm_aiter_ops.paged_attention_common(
                 Q=query[:num_decode_tokens],
                 K=new_key_cache,
                 V=new_value_cache,
+                tmp_out=tmp_out,
+                max_logits=max_logits,
+                exp_sums=exp_sums,
+                max_seq_len=attn_metadata.max_seq_len,
                 block_tables=attn_metadata.block_table[:num_decodes],
                 context_lens=attn_metadata.seq_lens[:num_decodes],
                 block_tables_stride0=attn_metadata.block_table[
                     :num_decodes
                 ].stride(0),
-                K_QScale=attn_metadata.k_scale,
-                V_QScale=attn_metadata.v_scale,
+                scale=self.scale,
+                K_QScale_hip=k_qscale,
+                V_QScale_hip=v_qscale,
+                K_QScale_asm=k_qscale,
+                V_QScale_asm=v_qscale,
                 out_=output[:num_decode_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
             )
         else:
             _, num_heads, head_size = query.shape