[Attention] Refactor attention metadata builder interface (#20466)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-07-17 00:44:25 -04:00
committed by GitHub
parent 28a6d5423d
commit 76b494444f
18 changed files with 1441 additions and 772 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
import torch
@@ -10,18 +10,13 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_rocm():
import aiter
@@ -172,54 +167,49 @@ logger = init_logger(__name__)
class AiterFlashAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.device = device
self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
def reorder_batch(self, input_batch, scheduler_output) -> bool:
return False
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
dtype=torch.int32,
device="cuda")
device=self.device)
torch.cumsum(seq_lens,
dim=0,
dtype=cu_seq_lens.dtype,
@@ -231,21 +221,21 @@ class AiterFlashAttentionMetadataBuilder:
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = int(seqlens_q_local_np.max())
local_max_seq_len = int(virt_k_seqlens_np.max())
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
@@ -256,12 +246,11 @@ class AiterFlashAttentionMetadataBuilder:
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
dtype=torch.int32,
device=self.runner.device)
device=self.device)
local_cu_seq_lens[1:] = torch.cumsum(
torch.from_numpy(virt_k_seqlens_np).to(
device=self.runner.device,
dtype=torch.int32,
non_blocking=True),
torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
dtype=torch.int32,
non_blocking=True),
dim=0)