[Attention] Make local attention backend agnostic (#21093)

This commit is contained in:
Lucas Wilkinson
2025-07-18 00:10:42 -04:00
committed by GitHub
parent b9a21e9173
commit 89cab4d01f
8 changed files with 94 additions and 242 deletions

View File

@@ -18,9 +18,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@@ -55,18 +54,6 @@ class TritonAttentionMetadata:
scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None
# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -111,34 +98,6 @@ class TritonAttentionMetadataBuilder(
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# for local attention
local_attn_metadata = None
if self.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.attention_chunk_size,
common_attn_metadata.query_start_loc_cpu.numpy(),
common_attn_metadata.seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_attn_metadata = TritonAttentionMetadata \
.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=None,
)
use_cascade = common_prefix_len > 0
if use_cascade:
@@ -170,7 +129,6 @@ class TritonAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
)
return attn_metadata
@@ -384,23 +342,11 @@ class TritonAttentionImpl(AttentionImpl):
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.