[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Authored by Chen Zhang on 2025-08-21 22:05:59 -07:00; committed by GitHub.
parent 5964069367
commit 17373dcd93
12 changed files with 226 additions and 214 deletions


@@ -8,6 +8,7 @@ import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import numpy as np
@@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
@@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype]
self.is_pooling_model = model_config.pooler_config is not None
self.is_encoder_only_model = False
self.is_multimodal_raw_input_supported = (
model_config.is_multimodal_raw_input_supported)
self.max_model_len = model_config.max_model_len
@@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.reorder_batch_threshold: Optional[int] = None
# Attention layers that are only in the KVCacheConfig of the runner
# (e.g., KV sharing, encoder-only attention), but not in the
# KVCacheConfig of the scheduler.
self.runner_only_attn_layers: set[str] = set()
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
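
The new runner_only_attn_layers set tracks attention layers that the runner manages itself (for example, encoder-only attention or KV sharing) but that never appear in the scheduler's KV cache config. A minimal, self-contained sketch of how such a set can be used to filter out layers that do not need real backing memory is below; the Group dataclass and the layer names are simplified, hypothetical stand-ins, not the actual vLLM classes.

# Minimal sketch (not the vLLM classes) of filtering runner-only layers.
from dataclasses import dataclass

@dataclass
class Group:
    layer_names: list[str]

runner_only_attn_layers = {"encoder.layers.0.attn", "encoder.layers.1.attn"}
kv_cache_groups = [
    Group(["decoder.layers.0.attn", "decoder.layers.1.attn"]),
    Group(["encoder.layers.0.attn", "encoder.layers.1.attn"]),
]

# Only layers that the scheduler also knows about get backing tensors.
layers_needing_memory = [
    name
    for group in kv_cache_groups
    for name in group.layer_names
    if name not in runner_only_attn_layers
]
print(layers_needing_memory)  # ['decoder.layers.0.attn', 'decoder.layers.1.attn']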
@@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata: dict[str, Any] = {}
# Prepare encoder attention metadata separately
# (encoder layers are not in KV cache groups)
if self.is_encoder_only_model:
per_layer_metadata = \
self._build_encoder_only_attn_metadata(
scheduler_output)
# Add encoder attention metadata for all encoder layers
attention_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attention_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
common_attn_metadata, encoder_attn_metadata =\
per_layer_metadata[layer_name]
attn_metadata[layer_name] = encoder_attn_metadata
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
@@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
(num_reqs, 1),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device, non_blocking=True)
slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
slot_mapping = blk_table.slot_mapping[:
total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
num_common_prefix_blocks = (
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id])
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
@@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
num_common_prefix_blocks,
kv_cache_group_spec.kv_cache_spec,
builder,
)
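
For encoder-only KV cache groups, the hunk above allocates an all-zero block table and slot mapping, since those layers never read or write KV cache but the metadata builders still expect tensors of the right shape. The pattern is a pinned CPU tensor followed by an asynchronous copy to the device; the sketch below illustrates it with made-up sizes and skips pinning when CUDA is unavailable.

# Sketch of the dummy block-table / slot-mapping allocation: pinned CPU
# tensors copied to the device with non_blocking=True so the H2D copy can
# overlap with other work. Sizes are illustrative only.
import torch

num_reqs = 4
total_num_scheduled_tokens = 128
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

dummy_block_table = torch.zeros(
    (num_reqs, 1), dtype=torch.int32,
    pin_memory=use_cuda, device="cpu").to(device, non_blocking=True)
dummy_slot_mapping = torch.zeros(
    (total_num_scheduled_tokens,), dtype=torch.int32,
    pin_memory=use_cuda, device="cpu").to(device, non_blocking=True)

# The zeros carry no real mapping; they only satisfy the builders' shapes.
assert dummy_block_table.shape == (num_reqs, 1)
assert dummy_slot_mapping.shape == (total_num_scheduled_tokens,)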
@@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
if len(self.attn_groups) > 0:
return
# Check if model is encoder-only
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
if attn_module.sliding_window is None:
attn_spec: AttentionSpec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
else:
attn_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
attn_specs[attn_spec].append(layer_name)
else:
raise ValueError("Expected only encoder-only layers")
if len(attn_specs) > 0:
total_layers = 0
for attn_spec, layer_names in attn_specs.items():
attn_backends = get_attn_backends_for_layers(layer_names)
total_layers += len(layer_names)
self.attn_groups.append(
create_attn_groups(attn_backends, attn_spec))
assert total_layers == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"
self.is_encoder_only_model = True
def initialize_cudagraph_capture(self) -> None:
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
@@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
layer_names.update(group.layer_names)
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
@@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
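
The tensor-carving loop above asserts that each layer's raw buffer is an exact multiple of the spec's page size and derives the usable block count by integer division. A small worked example of that arithmetic with made-up numbers; the page-size product in the comment is only a plausible illustration, not the exact vLLM formula.

# Worked example of the block-count arithmetic; numbers are illustrative.
# Assume the raw buffer is a byte tensor, so numel equals its size in bytes.
page_size_bytes = 2 * 16 * 8 * 128 * 2   # roughly: (K+V) * block_size * kv_heads * head_size * bytes/elem
raw_numel = 40 * page_size_bytes         # a buffer sized for 40 blocks

assert raw_numel % page_size_bytes == 0, "buffer must align to the page size"
num_blocks = raw_numel // page_size_bytes
print(num_blocks)  # 40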
@@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config.kv_cache_groups,
kv_caches,
self.attn_groups,
self.runner_only_attn_layers,
)
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
@@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
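
The KV cache initialization now deep-copies the incoming config before mutating it: the runner appends its encoder-only group to its own copy, so the scheduler's view of the KV cache groups stays untouched. A tiny illustration of that pattern, with plain dicts standing in for the real config objects:

# Why deepcopy: mutate the runner's copy without touching the caller's config.
from copy import deepcopy

scheduler_view = {"kv_cache_groups": [{"layers": ["decoder.0.attn"]}]}

runner_view = deepcopy(scheduler_view)
runner_view["kv_cache_groups"].append({"layers": ["encoder.0.attn"]})

print(len(scheduler_view["kv_cache_groups"]))  # 1 -- unchanged
print(len(runner_view["kv_cache_groups"]))     # 2 -- runner-only group added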
@@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec = EncoderOnlyAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
assert len(
encoder_only_attn_specs
) == 1, "Only support one encoder-only attention spec now"
spec, layer_names = encoder_only_attn_specs.popitem()
self.kv_cache_config.kv_cache_groups.append(
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
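
The new method buckets encoder-only layers by their attention spec and, since only a single spec is supported for now, pops the lone entry to build one extra KV cache group. A minimal sketch of that defaultdict/popitem pattern, with a simplified, hypothetical Spec in place of EncoderOnlyAttentionSpec:

# Grouping layers by a hashable spec, then retrieving the single group.
from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)          # frozen -> hashable, usable as a dict key
class Spec:
    block_size: int
    num_kv_heads: int
    head_size: int

specs: dict[Spec, list[str]] = defaultdict(list)
for layer_name in ["encoder.0.attn", "encoder.1.attn"]:
    specs[Spec(block_size=16, num_kv_heads=8, head_size=64)].append(layer_name)

assert len(specs) == 1, "only one encoder-only spec is supported for now"
spec, layer_names = specs.popitem()
print(spec, layer_names)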
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
@@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mamba_type=mamba_module.mamba_type)
return kv_cache_spec
def _build_encoder_only_attn_metadata(
self, scheduler_output: "SchedulerOutput") -> \
dict[str, tuple[CommonAttentionMetadata, Any]]:
"""Prepare encoder attention metadata for encoder-only models.
Args:
scheduler_output: Scheduler output
Returns:
dict[str, Any]: Encoder attention metadata
"""
num_reqs = self.input_batch.num_reqs
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
max_num_scheduled_tokens = max(tokens)
dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
for attn_group_list in self.attn_groups:
assert len(attn_group_list) == 1
attn_group = attn_group_list[0]
# Use the first attention metadata builder
# to create encoder attention metadata
builder = attn_group.metadata_builder
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
block_table_tensor=dummy_block_table,
slot_mapping=dummy_slot_mapping,
causal=False,
)
metadata = builder.build(
common_prefix_len=0, # No cascade for encoder
common_attn_metadata=common_metadata,
)
for layer_name in attn_group.layer_names:
group_metadata[layer_name] = (common_metadata, metadata)
return group_metadata
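
The deleted helper above sliced precomputed buffers such as query_start_loc and seq_lens and built non-causal metadata once per attention group; the unified per-KV-cache-group path now does the equivalent work. For readers unfamiliar with the query_start_loc convention it relies on, here is a minimal illustration (not the runner's actual buffers) of the cumulative-offset bookkeeping involved:

# Illustrative only: per-request scheduled-token counts -> cumulative offsets.
import numpy as np

num_scheduled_tokens = {"req-0": 7, "req-1": 3, "req-2": 5}
req_ids = ["req-0", "req-1", "req-2"]

tokens = [num_scheduled_tokens[r] for r in req_ids]
max_query_len = max(tokens)                                  # 7
query_start_loc = np.concatenate(([0], np.cumsum(tokens)))   # [0, 7, 10, 15]
total_num_scheduled_tokens = int(query_start_loc[-1])        # 15
print(max_query_len, query_start_loc, total_num_scheduled_tokens)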