[V1] Enable prefill optimization for Gemma3n (#22628)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import gc
|
||||
import itertools
|
||||
import time
|
||||
@@ -58,7 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
supports_dynamo)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
make_kv_sharing_fast_prefill_attention_metadata,
|
||||
create_fast_prefill_custom_backend,
|
||||
reorder_batch_to_split_decodes_and_prefills)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
@@ -84,9 +83,10 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin, KVConnectorOutput)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
|
||||
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||
from .utils import (AttentionGroup, MultiModalBudget,
|
||||
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
|
||||
gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
@@ -860,6 +860,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=blk_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
logits_indices_padded=logits_indices_padded,
|
||||
num_logits_indices=logits_indices.size(0),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
@@ -884,28 +886,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
))
|
||||
|
||||
fast_prefill_metadata = attn_metadata_i
|
||||
if (self.cache_config.kv_sharing_fast_prefill
|
||||
and self.kv_sharing_fast_prefill_eligible_layers):
|
||||
# Dynamically create a a dataclass type that inherits
|
||||
# from attention metadata type but includes additional
|
||||
# fields logits_indices_padded and num_logits_indices
|
||||
# which are required for prefill truncation
|
||||
fast_prefill_metadata_type = (
|
||||
make_kv_sharing_fast_prefill_attention_metadata(
|
||||
metadata_cls=type(attn_metadata_i), ))
|
||||
fast_prefill_metadata = fast_prefill_metadata_type(
|
||||
**dataclasses.asdict(attn_metadata_i),
|
||||
logits_indices_padded=logits_indices_padded,
|
||||
num_logits_indices=logits_indices.size(0),
|
||||
)
|
||||
|
||||
for layer_name in attn_group.layer_names:
|
||||
if (self.cache_config.kv_sharing_fast_prefill
|
||||
and layer_name
|
||||
in self.kv_sharing_fast_prefill_eligible_layers):
|
||||
attn_metadata[layer_name] = fast_prefill_metadata
|
||||
continue
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
# Hot-Swap lora model
|
||||
@@ -1484,6 +1465,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return self.kv_connector_no_forward(scheduler_output,
|
||||
self.vllm_config)
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
assert not self.input_batch.num_prompt_logprobs, (
|
||||
"--kv-sharing-fast-prefill produces incorrect logprobs for "
|
||||
"prompt tokens, tokens, please disable it when the requests "
|
||||
"need prompt logprobs")
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
(attn_metadata, logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||||
@@ -2742,6 +2729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# layer.
|
||||
for layer_name in layer_names:
|
||||
attn_backend = layers[layer_name].get_attn_backend()
|
||||
|
||||
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
|
||||
attn_backend = create_fast_prefill_custom_backend(
|
||||
"FastPrefill",
|
||||
attn_backend,
|
||||
)
|
||||
|
||||
key = attn_backend.full_cls_name()
|
||||
attn_backends[key] = attn_backend
|
||||
attn_backend_layers[key].append(layer_name)
|
||||
@@ -3074,20 +3068,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
# Setup `kv_cache_config` and `kv_caches` for models
|
||||
# with cross-layer KV sharing
|
||||
if self.shared_kv_cache_layers:
|
||||
initialize_kv_cache_for_kv_sharing(
|
||||
self.shared_kv_cache_layers,
|
||||
kv_cache_config.kv_cache_groups,
|
||||
kv_caches,
|
||||
self.attn_groups,
|
||||
self.runner_only_attn_layers,
|
||||
)
|
||||
# Set up cross-layer KV cache sharing
|
||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
|
||||
):
|
||||
logger.debug("%s reuses KV cache of %s", layer_name,
|
||||
target_layer_name)
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Add layers that re-use KV cache to KV cache group of its target layer.
|
||||
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
|
||||
"""
|
||||
if not self.shared_kv_cache_layers:
|
||||
# No cross-layer KV sharing, return
|
||||
return
|
||||
|
||||
add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self.shared_kv_cache_layers,
|
||||
kv_cache_config.kv_cache_groups,
|
||||
self.runner_only_attn_layers,
|
||||
)
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
|
||||
# similar KV sharing setups, only the layers that generate KV caches
|
||||
# are involved in the prefill phase, enabling prefill to early exit.
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
Attention)
|
||||
# Iterate in reversed order and add layers that re-use KV cache
|
||||
# e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
|
||||
for layer_name in reversed(attn_layers):
|
||||
if layer_name in self.shared_kv_cache_layers:
|
||||
self.kv_sharing_fast_prefill_eligible_layers.add(
|
||||
@@ -3095,11 +3109,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
break
|
||||
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
@@ -3111,6 +3120,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user