[Attention] Make local attention backend agnostic (#21093)

Author: Lucas Wilkinson
Date: 2025-07-18 00:10:42 -04:00
Committed by: GitHub
Parent: b9a21e9173
Commit: 89cab4d01f
8 changed files with 94 additions and 242 deletions
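
Context for the excerpt below (the AITER FlashAttention backend): the PR removes this backend's private local-attention plumbing. Chunked ("local") attention is instead expressed once, above the backend layer, by rewriting the batch into per-chunk virtual batches that any causal varlen attention kernel can consume unchanged. A minimal sketch of the virtual-batch idea; this is illustrative only, and the bookkeeping in vLLM's `make_local_attention_virtual_batches` is more involved:

```python
import numpy as np

def split_local(q_lens, seq_lens, chunk_size):
    """Split each request into per-chunk "virtual batches" so chunked
    local attention runs as ordinary causal varlen attention.
    Simplified sketch, not vLLM's implementation."""
    local_q_lens, local_k_lens = [], []
    for q_len, seq_len in zip(q_lens, seq_lens):
        pos = seq_len - q_len  # absolute position of the first query token
        while pos < seq_len:
            chunk_start = (pos // chunk_size) * chunk_size
            end = min(chunk_start + chunk_size, seq_len)
            local_q_lens.append(end - pos)          # queries landing in this chunk
            local_k_lens.append(end - chunk_start)  # keys visible to the last one
            pos = end
    return np.asarray(local_q_lens), np.asarray(local_k_lens)

# One request: 10 tokens total, 3 of them new, chunk size 4 -> two virtual
# batches: (q=1, k=4) for chunk [4, 8) and (q=2, k=2) for chunk [8, 10).
print(split_local([3], [10], 4))
```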


@@ -13,8 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
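
With this import gone, the AITER backend no longer reaches into the FlashAttention module for the virtual-batch helper; the rewrite happens before backend-specific builders run (the helper's new home is in another file of this commit, not shown in this excerpt). The enabling design point is that the rewrite consumes and returns the same metadata type. A hedged sketch with a hypothetical stand-in type; field names follow the ones the builder below reads, everything else is invented:

```python
from dataclasses import dataclass, replace
import torch

@dataclass
class CommonMeta:
    """Hypothetical stand-in for CommonAttentionMetadata."""
    query_start_loc: torch.Tensor    # cumulative query lengths, shape (B + 1,)
    seq_lens: torch.Tensor           # per-batch KV lengths, shape (B,)
    block_table_tensor: torch.Tensor

def as_local(meta: CommonMeta, local_qsl: torch.Tensor,
             local_k_lens: torch.Tensor,
             local_block_table: torch.Tensor) -> CommonMeta:
    # Same type in, same type out: a backend builder cannot tell virtual
    # (per-chunk) batches from real ones, so it needs no local-attention
    # special case.
    return replace(meta, query_start_loc=local_qsl,
                   seq_lens=local_k_lens,
                   block_table_tensor=local_block_table)
```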
@@ -201,9 +199,7 @@ class AiterFlashAttentionMetadataBuilder:
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
        query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
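
The two deleted host-side copies were consumed only by the removed local-attention block further down (as `.numpy()` arrays). The surviving context in the next hunk shows the usual exclusive-prefix-sum construction of `cu_seq_lens`; a standalone demonstration of that pattern:

```python
import torch

seq_lens = torch.tensor([10, 42, 7], dtype=torch.int32)
cu_seq_lens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
# Cumulative sum written directly into the tail slice, leaving the
# leading zero in place: [0, 10, 52, 59].
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
print(cu_seq_lens)
```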
@@ -215,56 +211,6 @@ class AiterFlashAttentionMetadataBuilder:
                     dtype=cu_seq_lens.dtype,
                     out=cu_seq_lens[1:])

-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
-
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max().item()
-            local_max_seq_len = virt_k_seqlens_np.max().item()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
-                                            dtype=torch.int32,
-                                            device=self.device)
-            local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
-                                                       dtype=torch.int32,
-                                                       non_blocking=True),
-                dim=0)
-
-            local_attn_metadata = \
-                AiterFlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=local_query_start_loc,
-                    local_seqused_k=local_seqused_k,
-                    local_block_table=virt_block_table_tensor,
-                    local_max_query_len=local_max_query_len,
-                    local_max_seq_len=local_max_seq_len,
-                    local_cu_seq_lens=local_cu_seq_lens,
-                    local_scheduler_metadata=local_scheduler_metadata,
-                )
-
        use_cascade = common_prefix_len > 0
        cu_prefix_query_lens = None
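
Everything deleted above is the per-backend half of the virtual-batch machinery: moving NumPy results onto the device, deriving maxima, and packing a `LocalAttentionMetadata`. One piece worth seeing outside the diff is how a virtual batch reuses its parent request's block table: it simply starts at the block containing the chunk's first key. An illustrative sketch, not vLLM's actual implementation:

```python
import numpy as np

def local_block_table(block_table: np.ndarray, req_idx: np.ndarray,
                      chunk_starts: np.ndarray, block_size: int) -> np.ndarray:
    """For each virtual batch, slice a window of the parent request's
    block-table row, beginning at the block holding the chunk's first key."""
    n_blocks = block_table.shape[1]
    out = np.zeros((len(req_idx), n_blocks), dtype=block_table.dtype)
    for i, (req, start) in enumerate(zip(req_idx, chunk_starts)):
        first = start // block_size
        row = block_table[req, first:]
        out[i, :row.shape[0]] = row
    return out

# Request 0 owns physical blocks [7, 8, 9, 10]; a chunk whose first key is
# position 32 with block_size=16 starts at block index 2, i.e. block 9.
bt = np.array([[7, 8, 9, 10]])
print(local_block_table(bt, np.array([0]), np.array([32]), 16))
```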
@@ -286,7 +232,6 @@ class AiterFlashAttentionMetadataBuilder:
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
        )
        return attn_metadata
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_cu_seq_lens: torch.Tensor
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None

class AiterFlashAttentionImpl(AttentionImpl):
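
Dropping the nested container removes the last backend-visible trace of local attention. The decision now happens once, above the backends: if the model config sets an attention chunk size, the common metadata is rewritten before any builder sees it. A hedged sketch of that control flow; the names here are illustrative, not vLLM's:

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ModelCfg:
    # Mirrors the `attention_chunk_size` check the deleted builder code
    # used to perform inside each backend.
    attention_chunk_size: Optional[int] = None

def prepare_metadata(cfg: ModelCfg, common, rewrite: Callable):
    """Apply the virtual-batch rewrite exactly once, before dispatching
    to any backend-specific metadata builder."""
    if cfg.attention_chunk_size is not None:
        common = rewrite(common, cfg.attention_chunk_size)
    return common  # backends receive ordinary-looking varlen metadata
```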
@@ -521,25 +453,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
                                       layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table

            if max_seqlen_q > 1:
                cu_seq_lens = attn_metadata.cu_seq_lens
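
With local attention gone, the non-cascade path reads one coherent set of fields; when chunked attention is active, those fields already describe virtual batches. The invariants the generic varlen path relies on can be checked in a few lines (values here are made up for illustration):

```python
import torch

# B batches, virtual or real, share one layout contract.
cu_seqlens_q = torch.tensor([0, 1, 3, 7], dtype=torch.int32)  # shape (B + 1,)
seqused_k = torch.tensor([4, 2, 7], dtype=torch.int32)        # shape (B,)

assert cu_seqlens_q.numel() == seqused_k.numel() + 1
q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert bool((q_lens <= seqused_k).all())  # causal: queries are a suffix of keys
max_seqlen_q = int(q_lens.max())   # 4
max_seqlen_k = int(seqused_k.max())  # 7
```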
@@ -557,9 +476,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
-                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                                  local_metadata.local_cu_seq_lens),
-                )
+                    cu_seqlens_k=cu_seq_lens)

        _, num_heads, head_size = query.shape
        _PARTITION_SIZE_ROCM = 256
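
The excerpt cuts off inside the decode path. For context, a fixed partition size like this is typically how paged-attention decode kernels split a long KV sequence across workgroups, each producing partial softmax results that a second pass combines; the AITER kernel's exact usage is not shown here, so the arithmetic below is a hedged back-of-envelope only:

```python
# Illustrative arithmetic only: each partition covers 256 KV positions,
# and a reduction pass then merges the per-partition partial results.
_PARTITION_SIZE_ROCM = 256
max_seqlen_k = 1000
num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM
print(num_partitions)  # 4 partitions for a 1000-token KV sequence
```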