[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -8,6 +8,7 @@ import time
 from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional, Union, cast

 import numpy as np
@@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import (
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
+                                        EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec, MambaSpec,
-                                        SlidingWindowSpec)
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec, SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             cache_config.cache_dtype]

         self.is_pooling_model = model_config.pooler_config is not None
-        self.is_encoder_only_model = False
         self.is_multimodal_raw_input_supported = (
             model_config.is_multimodal_raw_input_supported)
         self.max_model_len = model_config.max_model_len
@@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         self.reorder_batch_threshold: Optional[int] = None

+        # Attention layers that are only in the KVCacheConfig of the runner
+        # (e.g., KV sharing, encoder-only attention), but not in the
+        # KVCacheConfig of the scheduler.
+        self.runner_only_attn_layers: set[str] = set()
+
         # Cached outputs.
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
@@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         attn_metadata: dict[str, Any] = {}

-        # Prepare encoder attention metadata separately
-        # (encoder layers are not in KV cache groups)
-        if self.is_encoder_only_model:
-
-            per_layer_metadata = \
-                self._build_encoder_only_attn_metadata(
-                    scheduler_output)
-
-            # Add encoder attention metadata for all encoder layers
-            attention_layers = get_layers_from_vllm_config(
-                self.vllm_config, Attention)
-            for layer_name, attn_module in attention_layers.items():
-                if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                    common_attn_metadata, encoder_attn_metadata =\
-                        per_layer_metadata[layer_name]
-                    attn_metadata[layer_name] = encoder_attn_metadata
-
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
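
The special-cased encoder path above can go because the next hunk folds encoder-only layers into the ordinary per-KV-cache-group loop, with the group's spec type selecting how the block table and slot mapping are produced. A minimal, self-contained sketch of that dispatch shape, using illustrative stand-in classes rather than the real vLLM API:

    # Stand-ins for vLLM's AttentionSpec family; all names here are hypothetical.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class FullSpec:
        block_size: int

    @dataclass(frozen=True)
    class EncoderOnlySpec:
        block_size: int

    def build_inputs(spec, num_reqs):
        if isinstance(spec, EncoderOnlySpec):
            # No KV cache: zero placeholders keep the interface uniform.
            return [[0]] * num_reqs, [0] * num_reqs
        # A real implementation would read the group's block table here.
        return [[1, 2]] * num_reqs, [1, 2] * num_reqs

    for spec in (FullSpec(16), EncoderOnlySpec(16)):
        blocks, slots = build_inputs(spec, num_reqs=2)
        print(type(spec).__name__, blocks, slots)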
@@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):

-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
-            slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    pin_memory=self.pin_memory,
+                    device="cpu").to(self.device, non_blocking=True)
+                slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
+                                           dtype=torch.int32,
+                                           pin_memory=self.pin_memory,
+                                           device="cpu").to(self.device,
+                                                            non_blocking=True)
+                num_common_prefix_blocks = 0
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
+                slot_mapping = blk_table.slot_mapping[:
+                                                      total_num_scheduled_tokens]

-            # Fill unused with -1. Needed for reshape_and_cache in full cuda
-            # graph mode.
-            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+                # Fill unused with -1. Needed for reshape_and_cache in full cuda
+                # graph mode.
+                blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+                num_common_prefix_blocks = (
+                    scheduler_output.
+                    num_common_prefix_blocks[kv_cache_group_id])

         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc,
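
The dummy tensors above follow the pinned-memory staging idiom: allocate in page-locked host memory, then copy with non_blocking=True so the host-to-device transfer can overlap other work. A minimal sketch with stock PyTorch, assuming only an optional CUDA device:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinning only pays off for CUDA copies; skip it on CPU-only runs.
    pin = device.type == "cuda"

    dummy = torch.zeros((4, 1), dtype=torch.int32, pin_memory=pin,
                        device="cpu").to(device, non_blocking=True)
    print(dummy.device, dummy.shape)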
@@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.cascade_attn_enabled:
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
-                    scheduler_output.
-                    num_common_prefix_blocks[kv_cache_group_id],
+                    num_common_prefix_blocks,
                     kv_cache_group_spec.kv_cache_spec,
                     builder,
                 )
@@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Calculate reorder batch threshold (if neeeded)
         self.calculate_reorder_batch_threshold()

-        if len(self.attn_groups) > 0:
-            return
-
-        # Check if model is encoder-only
-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
-        for layer_name, attn_module in attn_layers.items():
-
-            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                if attn_module.sliding_window is None:
-                    attn_spec: AttentionSpec = FullAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        use_mla=use_mla)
-                else:
-                    attn_spec = SlidingWindowSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        sliding_window=attn_module.sliding_window,
-                        use_mla=use_mla)
-                attn_specs[attn_spec].append(layer_name)
-
-            else:
-                raise ValueError("Expected only encoder-only layers")
-
-        if len(attn_specs) > 0:
-            total_layers = 0
-            for attn_spec, layer_names in attn_specs.items():
-
-                attn_backends = get_attn_backends_for_layers(layer_names)
-                total_layers += len(layer_names)
-
-                self.attn_groups.append(
-                    create_attn_groups(attn_backends, attn_spec))
-            assert total_layers == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
-            self.is_encoder_only_model = True
-
     def initialize_cudagraph_capture(self) -> None:
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_builder_name = None
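
Both the deleted block and its replacement later in this diff (may_add_encoder_only_layers_to_kv_cache_config) group layers by using spec objects as dict keys, which works because the spec dataclasses hash and compare by value. A sketch with a hypothetical FakeSpec standing in for the real spec classes:

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class FakeSpec:  # hypothetical stand-in for an AttentionSpec subclass
        block_size: int
        num_kv_heads: int
        head_size: int

    specs: dict[FakeSpec, list[str]] = defaultdict(list)
    for name in ("layers.0.attn", "layers.1.attn"):
        specs[FakeSpec(16, 8, 64)].append(name)

    # Value-equal specs hash equal, so both layers share one group.
    assert len(specs) == 1
    print(specs)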
@@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
-            layer_names.update(group.layer_names)
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                layer_names.add(layer_name)
         assert layer_names == set(kv_cache_raw_tensors.keys(
         )), "Some layers are not correctly initialized"
         return kv_cache_raw_tensors
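
The expanded loop is behaviorally a set difference: runner-only layers get no raw KV cache tensors, so they must be left out when validating that every layer was initialized. Spelled out with hypothetical layer names:

    group_layers = {"encoder.0.attn", "decoder.0.attn"}
    runner_only = {"encoder.0.attn"}

    layer_names = set()
    for layer_name in group_layers:
        if layer_name in runner_only:
            continue
        layer_names.add(layer_name)

    assert layer_names == group_layers - runner_only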
@@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             attn_backend = group.backend
             for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                 num_blocks = (raw_tensor.numel() //
@@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config.kv_cache_groups,
             kv_caches,
             self.attn_groups,
+            self.runner_only_attn_layers,
         )
         attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention)
@@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
+        kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
         self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

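
The new deepcopy is what makes the next hunk safe: the runner appends an encoder-only group to its own kv_cache_config, and without the copy the scheduler's aliased config would grow as well. A plain-Python sketch of the aliasing hazard, with a dict standing in for the real KVCacheConfig:

    from copy import deepcopy

    scheduler_cfg = {"kv_cache_groups": [["decoder.0.attn"]]}
    runner_cfg = deepcopy(scheduler_cfg)
    runner_cfg["kv_cache_groups"].append(["encoder.0.attn"])

    assert len(scheduler_cfg["kv_cache_groups"]) == 1  # unchanged
    assert len(runner_cfg["kv_cache_groups"]) == 2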
@@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)

+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
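
The popitem() at the end of the new method leans on the assert just before it: after grouping, at most one distinct encoder-only spec may remain, and popitem() returns that single (spec, layer_names) pair. A sketch of the contract, with a string key standing in for the spec object:

    from collections import defaultdict

    groups: dict[str, list[str]] = defaultdict(list)
    for name in ("encoder.0.attn", "encoder.1.attn"):
        groups["encoder_only_spec"].append(name)

    assert len(groups) == 1, "Only support one encoder-only attention spec now"
    spec, layer_names = groups.popitem()
    print(spec, layer_names)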
@@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 mamba_type=mamba_module.mamba_type)

         return kv_cache_spec
-
-    def _build_encoder_only_attn_metadata(
-            self, scheduler_output: "SchedulerOutput") -> \
-            dict[str, tuple[CommonAttentionMetadata, Any]]:
-        """Prepare encoder attention metadata for encoder-only models.
-
-        Args:
-            scheduler_output: Scheduler output
-
-        Returns:
-            dict[str, Any]: Encoder attention metadata
-        """
-        num_reqs = self.input_batch.num_reqs
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-
-        # Get the number of scheduled tokens for each request.
-        req_ids = self.input_batch.req_ids
-        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
-        max_num_scheduled_tokens = max(tokens)
-
-        dummy_block_table = torch.zeros((num_reqs, 1),
-                                        dtype=torch.int32,
-                                        pin_memory=self.pin_memory,
-                                        device="cpu").to(self.device,
-                                                         non_blocking=True)
-        dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
-                                         dtype=torch.int32,
-                                         pin_memory=self.pin_memory,
-                                         device="cpu").to(self.device,
-                                                          non_blocking=True)
-
-        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
-
-        for attn_group_list in self.attn_groups:
-
-            assert len(attn_group_list) == 1
-            attn_group = attn_group_list[0]
-
-            # Use the first attention metadata builder
-            # to create encoder attention metadata
-            builder = attn_group.metadata_builder
-
-            common_metadata = CommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
-                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-                seq_lens=self.seq_lens[:num_reqs],
-                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-                num_computed_tokens_cpu=self.input_batch.
-                num_computed_tokens_cpu_tensor[:num_reqs],
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
-                max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
-                block_table_tensor=dummy_block_table,
-                slot_mapping=dummy_slot_mapping,
-                causal=False,
-            )
-
-            metadata = builder.build(
-                common_prefix_len=0,  # No cascade for encoder
-                common_attn_metadata=common_metadata,
-            )
-
-            for layer_name in attn_group.layer_names:
-                group_metadata[layer_name] = (common_metadata, metadata)
-
-        return group_metadata
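
The removed builder passed causal=False when constructing CommonAttentionMetadata, which is what makes encoder-only attention bidirectional: every token may attend to every other token in its sequence, rather than only to earlier positions. A sketch of the two mask shapes, where True means "may attend":

    import torch

    seq_len = 4
    causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

    print(causal_mask)         # lower-triangular: decoder-style
    print(bidirectional_mask)  # all-ones: encoder-style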