# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
# Required to register custom ops.
import torch_xla.experimental.custom_kernel  # noqa: F401

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState

NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "PALLAS_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")


@dataclass
class PallasMetadata:

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Used in the PallasAttentionBackendImpl
    slot_mapping: torch.Tensor
    block_tables: torch.Tensor
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: int


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError("Paged attention Pallas kernel does "
                             "not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes are not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
                        [num_kv_heads, num_blocks, block_size, head_size])
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # For determine_available_memory case.
        if kv_cache[0].numel() == 0:
            if output is None:
                output = torch.ones_like(query)
            return output

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens, hidden_size = query.shape
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
        value = value.view(num_tokens, self.num_kv_heads, self.head_size)

        key_cache, value_cache = kv_cache
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # Apply the softmax scale to the query up front; the kernel call
        # below does not take an explicit scale argument.
        query = query * self.scale
        output = torch.ops.xla.ragged_paged_attention(
            query,
            key_cache,
            value_cache,
            attn_metadata.context_lens,
            attn_metadata.block_tables,
            attn_metadata.query_start_loc,
            attn_metadata.num_seqs,
            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
            use_kernel=False,
        )

        return output.reshape(num_tokens, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write the key and value tensors to the KV cache.

    Args:
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        key_cache: shape = [num_kv_heads, num_blocks, block_size, head_size]
        value_cache: shape = [num_kv_heads, num_blocks, block_size, head_size]
    """
    # Mark the cache buffers as donated so XLA can update them in place.
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    # Flatten key/value to [num_tokens * num_kv_heads, head_size] and the
    # caches to [num_kv_heads * num_blocks * block_size, head_size], so each
    # slot_mapping entry addresses one (token, kv_head) row in the cache.
    key = key.flatten(0, 1)
    value = value.flatten(0, 1)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)
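

# ---------------------------------------------------------------------------
# Illustration only (not used by vLLM at runtime): a minimal, CPU-only sketch
# of the paged KV-cache write performed by write_to_kv_cache() above, using
# made-up shapes and a hand-written slot_mapping. It mirrors the
# flatten + index_copy_ scheme but skips the XLA buffer-donation ops, so it
# does not require a TPU.
if __name__ == "__main__":
    num_kv_heads, num_blocks, block_size, head_size = 1, 4, 16, 128
    num_tokens = 3

    # Cache layout matches get_kv_cache_shape():
    # [num_kv_heads, num_blocks, block_size, head_size].
    key_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)
    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # One flat slot index per (token, kv_head) pair; here the three tokens go
    # into the first three slots of block 2: 2 * block_size + [0, 1, 2].
    slot_mapping = torch.tensor([32, 33, 34], dtype=torch.long)

    # [num_kv_heads * num_blocks * block_size, head_size]
    flat_cache = key_cache.flatten(0, 2)
    flat_cache.index_copy_(0, slot_mapping, key.flatten(0, 1))

    # flatten() returns a view here, so the original cache was updated
    # in place.
    assert torch.equal(key_cache[0, 2, :3], key[:, 0])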