# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""

from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionImpl,
    AttentionLayer,
    AttentionMetadataBuilder,
    AttentionType,
    CommonAttentionMetadata,
    MultipleOf,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.paged_attn import PagedAttention
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
)
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)


@dataclass
class RocmAttentionMetadata:
    """Per-batch metadata produced by RocmAttentionMetadataBuilder."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention. Only populated when common_prefix_len > 0;
    # RocmAttentionBackend.use_cascade_attention always returns False.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None


class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
    """Builds RocmAttentionMetadata from the shared CommonAttentionMetadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        model_config = vllm_config.model_config
        self.num_heads_q = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
        self.headdim = model_config.get_head_size()

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> RocmAttentionMetadata:
        """Build metadata suitable for full CUDA graph capture.

        The returned metadata shares its ``seq_lens`` and ``query_start_loc``
        tensors with ``common_attn_metadata`` (``build`` passes them through
        unchanged), so the in-place edits below are visible through both.
        """
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens.fill_(1)

        # Here we set the query start locs to 0. This is to
        # cover up an invalid memory access in the prefix_prefill kernel
        # that we run into during graph capture (#25985)
        common_attn_metadata.query_start_loc.zero_()
        common_attn_metadata.query_start_loc_cpu.zero_()

        return attn_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> RocmAttentionMetadata:
        """Assemble RocmAttentionMetadata for one batch.

        Args:
            common_prefix_len: Length of the shared prefix; cascade tensors
                are only created when this is > 0.
            common_attn_metadata: Backend-agnostic metadata; its tensors are
                referenced directly (not copied).
            fast_build: Accepted for interface compatibility; unused here.
        """
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        use_cascade = common_prefix_len > 0

        if use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
            suffix_kv_lens = suffix_kv_lens.to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
        prefix_scheduler_metadata = None

        return RocmAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )


class RocmAttentionBackend(AttentionBackend):
    """Attention backend for ROCm using paged decode + Triton prefix prefill."""

    accept_output_buffer: bool = True
    # KV-cache writes happen in RocmAttentionImpl.do_kv_cache_update rather
    # than inside forward().
    forward_includes_kv_cache_update: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # The native ROCm paged attention kernel only supports block sizes 16
        # and 32 due to shared memory (LDS) constraints on AMD GPUs.
        # See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
        # Those limits are reasonable for the C++ kernel, but non-standard
        # sizes should still work via the Triton path: e.g. Qwen3-Next needs
        # a block_size of 544 (https://github.com/vllm-project/vllm/pull/31380).
        # The Triton kernel was adapted so that standard (power-of-two) block
        # sizes keep the original bit-addressing logic, while non-standard
        # sizes use the new addressing logic. do_kv_cache_update routes
        # non-power-of-two block sizes to triton_reshape_and_cache_flash.
        return [16, 32, 544]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ValueError if ``head_size`` is not supported by this backend."""
        if not cls.supports_head_size(head_size):
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
                "Set --attention-backend=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes."
            )

    @staticmethod
    def get_name() -> str:
        return "ROCM_ATTN"

    @staticmethod
    def get_impl_cls() -> type["RocmAttentionImpl"]:
        return RocmAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape: keys and values stacked on dim 0.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        # Cascade attention is never used by this backend.
        return False

    @staticmethod
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        return RocmAttentionMetadataBuilder


class RocmAttentionImpl(AttentionImpl):
    """Attention implementation backed by chunked_prefill_paged_decode."""

    def fused_output_quant_supported(self, quant_key: QuantKey):
        # Only static per-tensor symmetric FP8 output quantization is fused.
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (-1, -1) when no sliding window is configured; only element [0]
        # is forwarded to the kernel.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        elif logits_soft_cap != 0:
            # forward() does not pass logits_soft_cap to the kernel, so a
            # requested soft cap would otherwise be ignored without notice.
            logger.warning_once(
                "logits_soft_cap is not applied by RocmAttentionImpl; "
                "attention logits will not be soft-capped."
            )
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        RocmAttentionBackend.validate_head_size(head_size)

        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
            raise NotImplementedError(
                "Encoder self-attention is not implemented for RocmAttentionImpl"
            )

        self.fp8_dtype = current_platform.fp8_dtype()

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass using the chunked-prefill / paged-decode kernel.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. In practice this is the
                RocmAttentionMetadata built by RocmAttentionMetadataBuilder;
                None during the profiling run.
            output: Preallocated output buffer (required); written in place.
            output_scale: Optional scale for fused output quantization.
            output_block_scale: Not supported; must be None.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: If ``output_block_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for RocmAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size
        )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache storage as the platform FP8 dtype.
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        # Compute attention and update output up to `num_actual_tokens`.
        chunked_prefill_paged_decode(
            query=query[:num_actual_tokens],
            key=key[:num_actual_tokens] if key is not None else None,
            value=value[:num_actual_tokens] if value is not None else None,
            output=output[:num_actual_tokens],
            kv_cache_dtype=self.kv_cache_dtype,
            key_cache=key_cache,
            value_cache=value_cache,
            block_table=block_table,
            query_start_loc=cu_seqlens_q,
            seq_lens=seqused_k,
            max_seq_len=max_seqlen_k,
            max_query_len=max_seqlen_q,
            k_scale=layer._k_scale,
            v_scale=layer._v_scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window[0],
            sm_scale=self.scale,
            output_scale=output_scale,
            sinks=self.sinks,
        )
        return output

    def do_kv_cache_update(
        self,
        layer: AttentionLayer,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ):
        """Write ``key``/``value`` into the paged KV cache at ``slot_mapping``.

        Power-of-two block sizes use PagedAttention.write_to_paged_cache;
        any other block size (e.g. 544) uses triton_reshape_and_cache_flash.
        Nothing is written when this layer shares another layer's KV cache
        or when ``key``/``value`` is None.
        """
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size
        )

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.

            # Get the actual block_size from value_cache
            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
            block_size = value_cache.shape[3]
            # Power of two iff exactly one bit is set.
            is_pow2 = block_size > 0 and (block_size & (block_size - 1)) == 0

            if is_pow2:
                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                # Non-standard block sizes (e.g., 544 in Qwen3-Next) go
                # through the modified Triton logic.
                triton_reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )