[Performance] Split FlashAttn attention and cache update (#25954)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
ElizaWszola authored on 2026-01-24 02:28:06 +01:00 · committed by GitHub
parent 0118cdcc02 · commit a28b94e6ef
21 changed files with 458 additions and 68 deletions

View File

@@ -390,29 +390,43 @@ class Attention(nn.Module, AttentionLayerBase):
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update:
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
)
unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
else:
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update and (
# torch can only dispatch custom op if a tensor is passed
key is not None or value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output.view(-1, hidden_size)
else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
)
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
return unified_attention(query, key, value, self.layer_name)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
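
The split path above hinges on a dummy-dependency trick: the KV-cache update op returns a 0-element tensor that the attention op accepts (and ignores), so torch.compile sees a data dependency and cannot reorder or drop the side-effecting update. A minimal stand-alone sketch of the pattern (not vLLM code):

import torch

_CACHE: dict[str, torch.Tensor] = {}

def kv_cache_update(key: torch.Tensor, value: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Side effect: stash key/value for this layer.
    _CACHE[layer_name] = torch.stack([key, value])
    # 0-element dummy that downstream ops can depend on.
    return torch.empty(0, dtype=key.dtype)

def attention(query: torch.Tensor, layer_name: str,
              kv_cache_dummy_dep: torch.Tensor | None = None) -> torch.Tensor:
    # The dummy is unused; accepting it only orders this op after the update.
    del kv_cache_dummy_dep
    key, value = _CACHE[layer_name].unbind(0)
    scores = (query @ key.transpose(-1, -2)).softmax(dim=-1)
    return scores @ value

q = k = v = torch.randn(4, 8)
dep = kv_cache_update(k, v, "layer0")
out = attention(q, "layer0", kv_cache_dummy_dep=dep)
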
@@ -802,6 +816,55 @@ direct_register_custom_op(
)
def unified_kv_cache_update(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
"""
Returns a dummy tensor that is passed to the unified attention op to signal the
side effect and create a data dependency between the KV cache update and the
attention, so that torch.compile preserves their ordering.
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
)
attn_layer.impl.do_kv_cache_update(
attn_layer,
key,
value,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def unified_kv_cache_update_fake(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty(0, device=key.device, dtype=key.dtype)
direct_register_custom_op(
op_name="unified_kv_cache_update",
op_func=unified_kv_cache_update,
fake_impl=unified_kv_cache_update_fake,
mutates_args=[],
)
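
direct_register_custom_op is vLLM's own registration helper. For intuition, a rough stand-alone equivalent using the public torch.library API (assuming PyTorch 2.4+, names hypothetical) looks like:

import torch

@torch.library.custom_op("demo::kv_cache_update", mutates_args=())
def demo_kv_cache_update(key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Eager implementation: write key/value into the paged KV cache (omitted here),
    # then return a 0-element dummy tensor.
    return torch.empty(0, device=key.device, dtype=key.dtype)

@demo_kv_cache_update.register_fake
def _(key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation used while torch.compile traces the graph:
    # same shape/dtype, no side effects.
    return torch.empty(0, device=key.device, dtype=key.dtype)
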
@maybe_transfer_kv_layer
def unified_attention_with_output(
query: torch.Tensor,
@@ -811,7 +874,12 @@ def unified_attention_with_output(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
@@ -835,6 +903,7 @@ def unified_attention_with_output_fake(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
return

View File

@@ -189,6 +189,7 @@ class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
"""
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
@@ -266,6 +267,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | None = None,
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
@@ -282,6 +284,7 @@ def create_forward_context(
remaining_moe_layers=remaining_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
@@ -316,6 +319,7 @@ def set_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
skip_compiled: bool = False,
):
"""A context manager that stores the current forward context,
@@ -374,6 +378,7 @@ def set_forward_context(
cudagraph_runtime_mode,
batch_descriptor,
ubatch_slices,
slot_mapping,
additional_kwargs,
skip_compiled,
)
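
The new slot_mapping field lets the cache-update custom op look up its layer's slot mapping from the ambient forward context rather than from attention metadata. A toy sketch of that flow (not vLLM code, just the shape of the mechanism):

from contextlib import contextmanager
from dataclasses import dataclass, field

import torch

@dataclass
class ToyForwardContext:
    slot_mapping: dict[str, torch.Tensor] = field(default_factory=dict)

_current: ToyForwardContext | None = None

@contextmanager
def toy_set_forward_context(slot_mapping: dict[str, torch.Tensor] | None = None):
    global _current
    _current = ToyForwardContext(slot_mapping=slot_mapping or {})
    try:
        yield _current
    finally:
        _current = None

with toy_set_forward_context({"model.layers.0.attn": torch.arange(4)}):
    # Inside the model, the cache-update op reads its own layer's slots:
    slots = _current.slot_mapping.get("model.layers.0.attn")
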

View File

@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
subclass_attention_backend,
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
@@ -72,6 +72,7 @@ def create_cross_attention_backend(
) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(
@@ -106,18 +107,60 @@ def create_cross_attention_backend(
)
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
new_metadata.slot_mapping = _get_cross_slot_mapping(
slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor,
self.kv_cache_spec,
self.device,
)
return super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata.slot_mapping = slot_mapping
return attn_metadata
attn_backend = subclass_attention_backend(
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
# `CrossAttentionBuilder` instead of the one computed by `BlockTable`
# (gpu_model_runner)
class CrossAttentionImpl(underlying_impl): # type: ignore[valid-type,misc]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
)
return super().forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output,
output_scale,
output_block_scale,
)
attn_backend = subclass_attention_backend_with_overrides(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder,
overrides={
"get_builder_cls": lambda: CrossAttentionBuilder,
"get_impl_cls": lambda: CrossAttentionImpl,
"forward_includes_kv_cache_update": True,
},
)
return attn_backend
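
subclass_attention_backend_with_overrides is a vLLM helper; the general shape of the pattern is building a dynamic subclass whose attributes and hooks come from the overrides dict. A rough, hypothetical illustration:

def subclass_with_overrides(name_prefix: str, base_cls: type, overrides: dict) -> type:
    # Build "<prefix><Base>" with the given attributes overriding the base class.
    return type(name_prefix + base_cls.__name__, (base_cls,), dict(overrides))

class ToyBackend:
    forward_includes_kv_cache_update = False

    @staticmethod
    def get_impl_cls():
        return object

ToyCrossBackend = subclass_with_overrides(
    "CrossAttention_",
    ToyBackend,
    {
        "forward_includes_kv_cache_update": True,   # cross-attn impl updates the cache itself
        "get_impl_cls": staticmethod(lambda: dict), # stand-in for CrossAttentionImpl
    },
)

assert ToyCrossBackend.forward_includes_kv_cache_update is True
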

View File

@@ -613,8 +613,9 @@ def weak_ref_tensor(tensor: Any) -> Any:
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
This ignores 0-size tensors as those don't allocate any memory.
"""
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, torch.Tensor) and tensor.numel() > 0:
return torch.ops._C.weak_ref_tensor(tensor)
else:
return tensor

View File

@@ -53,6 +53,9 @@ class AttentionBackend(ABC):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
# Does attention's forward() include kv cache update?
forward_includes_kv_cache_update: bool = True
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)]

View File

@@ -79,6 +79,8 @@ class FlashAttentionBackend(AttentionBackend):
return [16, 32, 64]
return [MultipleOf(16)]
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@@ -652,32 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
@@ -774,6 +750,49 @@ class FlashAttentionImpl(AttentionImpl):
)
return output
def do_kv_cache_update(
self,
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is not None
or key is None
or value is None
):
return
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
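
reshape_and_cache_flash scatters each token's key/value row into the paged cache at the slot given by slot_mapping; padded tokens carry slot -1 and are skipped (which is why the runner fills unused slots with -1). A toy dense-tensor analogue, ignoring the real paged / possibly-fp8 cache layout:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks * block_size, num_kv_heads, head_size)
value_cache = torch.zeros_like(key_cache)

key = torch.randn(3, num_kv_heads, head_size)    # 3 new tokens
value = torch.randn_like(key)
slot_mapping = torch.tensor([5, 6, 37])          # flat slot = block_id * block_size + offset

valid = slot_mapping >= 0                        # slot -1 marks padding and is not written
key_cache[slot_mapping[valid]] = key[valid]
value_cache[slot_mapping[valid]] = value[valid]
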
def _forward_with_dcp(
self,
query: torch.Tensor,

View File

@@ -159,6 +159,10 @@ class SpecDecodeBaseProposer:
with_numpy=True,
)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
)
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm():
@@ -230,6 +234,24 @@ class SpecDecodeBaseProposer:
positions = positions[0]
self.positions[:num_tokens] = positions
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for EAGLE layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
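
A toy version of this buffer handling: the draft model's slot mapping is copied into a persistent buffer and the tail is padded with a sentinel slot id (assumed -1 here) so CUDA-graph-sized launches never write real cache slots for padded tokens. Names below are hypothetical:

import torch

PADDING_SLOT_ID = -1  # assumed sentinel value for this sketch
slot_buffer = torch.empty(8, dtype=torch.int64)

def padded_slot_mapping(num_tokens: int, slot_mapping: torch.Tensor) -> torch.Tensor:
    n = slot_mapping.shape[0]
    slot_buffer[:n].copy_(slot_mapping)
    if num_tokens > n:
        slot_buffer[n:num_tokens].fill_(PADDING_SLOT_ID)
    return slot_buffer[:num_tokens]

print(padded_slot_mapping(6, torch.tensor([3, 4, 5, 6])))  # tensor([ 3,  4,  5,  6, -1, -1])
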
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle.
@@ -262,6 +284,9 @@ class SpecDecodeBaseProposer:
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> torch.Tensor:
batch_size = common_attn_metadata.batch_size()
@@ -358,6 +383,9 @@ class SpecDecodeBaseProposer:
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
@@ -396,6 +424,7 @@ class SpecDecodeBaseProposer:
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
slot_mappings=slot_mappings,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
@@ -553,6 +582,9 @@ class SpecDecodeBaseProposer:
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
input_batch_size, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
@@ -760,6 +792,9 @@ class SpecDecodeBaseProposer:
# [num_tokens, hidden_size]
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_groups[0][
0
@@ -881,6 +916,9 @@ class SpecDecodeBaseProposer:
self.vllm_config,
num_tokens=num_input_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, attn_metadata.slot_mapping
),
):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
@@ -1212,6 +1250,7 @@ class SpecDecodeBaseProposer:
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
@@ -1233,12 +1272,23 @@ class SpecDecodeBaseProposer:
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self.attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
if self.supports_mm_inputs:
input_ids = None

View File

@@ -38,6 +38,9 @@ class MedusaProposer:
self,
target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None, # unused
) -> torch.Tensor:
# Generate blocks and compute logits
blocks = self.model(target_hidden_states)

View File

@@ -3,6 +3,7 @@
import os
import numpy as np
import torch
from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig
@@ -132,6 +133,9 @@ class NgramProposer:
sampled_token_ids: list[list[int]],
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None, # unused
) -> list[list[int]]:
# find which requests need ngram proposals
valid_ngram_requests = []

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -33,6 +35,9 @@ class SuffixDecodingProposer:
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None, # unused
) -> list[list[int]]:
"""
Propose speculative tokens for each request in the input batch. Suffix Decoding

View File

@@ -140,6 +140,18 @@ def init_kv_cache(
return kv_caches
def build_slot_mappings_by_layer(
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, torch.Tensor]:
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
slot_mapping = slot_mappings[i]
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
return slot_mappings_by_layer
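
In other words, one slot-mapping row per KV-cache group is fanned out to every layer in that group, so the forward context can be indexed by layer name. A small hypothetical usage:

import torch

slot_mappings = torch.tensor([[0, 1, 2, 3],      # KV-cache group 0
                              [8, 9, 10, 11]])   # KV-cache group 1
group_layers = {0: ["model.layers.0.attn", "model.layers.1.attn"],
                1: ["model.layers.2.attn"]}

by_layer = {
    name: slot_mappings[gid]
    for gid, names in group_layers.items()
    for name in names
}
# by_layer["model.layers.1.attn"] is the group-0 row, shared (not copied) across layers.
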
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,

View File

@@ -14,7 +14,10 @@ from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
@@ -88,7 +91,7 @@ class CudaGraphManager:
positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens]
attn_metadata = prepare_inputs_to_capture(
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
@@ -98,6 +101,9 @@ class CudaGraphManager:
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
# Warm up.
with set_forward_context(
@@ -106,6 +112,7 @@ class CudaGraphManager:
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer,
):
hidden_states = model(
input_ids=input_ids,
@@ -125,6 +132,7 @@ class CudaGraphManager:
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer,
),
torch.cuda.graph(graph, self.pool),
):
@@ -244,7 +252,7 @@ def prepare_inputs_to_capture(
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
) -> tuple[dict[str, Any], torch.Tensor]:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
@@ -274,4 +282,4 @@ def prepare_inputs_to_capture(
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
return attn_metadata
return attn_metadata, slot_mappings

View File

@@ -24,6 +24,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
get_kv_cache_spec,
init_attn_backend,
init_kv_cache,
@@ -881,6 +882,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope:
assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions[: input_batch.num_tokens],
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
@@ -888,6 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): Support piecewise CUDA graph.
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer,
):
self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.model(

View File

@@ -314,6 +314,7 @@ class ExecuteModelState(NamedTuple):
aux_hidden_states: list[torch.Tensor] | None
ec_connector_output: ECConnectorOutput | None
cudagraph_stats: CUDAGraphStat | None
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None
class GPUModelRunner(
@@ -1595,6 +1596,7 @@ class GPUModelRunner(
for_cudagraph_capture: bool = False,
num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
@@ -1628,7 +1630,7 @@ class GPUModelRunner(
kv_cache_groups = self.kv_cache_config.kv_cache_groups
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
def _get_block_table(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
@@ -1637,24 +1639,19 @@ class GPUModelRunner(
dtype=torch.int32,
device=self.device,
)
slot_mapping = torch.zeros(
(num_tokens_padded,),
dtype=torch.int64,
device=self.device,
)
else:
blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
return blk_table_tensor
return blk_table_tensor, slot_mapping
assert slot_mappings is not None
block_table_gid_0 = _get_block_table(0)
slot_mapping_gid_0 = slot_mappings[0]
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
cm_base = CommonAttentionMetadata(
@@ -1779,9 +1776,8 @@ class GPUModelRunner(
num_reqs_padded,
)
if kv_cache_gid > 0:
cm.block_table_tensor, cm.slot_mapping = (
_get_block_table_and_slot_mapping(kv_cache_gid)
)
cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid]
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
@@ -3119,6 +3115,80 @@ class GPUModelRunner(
pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
self.layerwise_nvtx_hooks_registered = True
def _get_slot_mappings(
self,
num_tokens_padded: int,
num_reqs_padded: int,
num_tokens_unpadded: int,
ubatch_slices: "UBatchSlices | None" = None,
) -> tuple[
dict[int, torch.Tensor] | None,
dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
]:
"""
Build slot mappings in both formats needed by the system.
Args:
num_tokens_padded: Total number of tokens (padded)
num_reqs_padded: Total number of requests (padded)
num_tokens_unpadded: Actual number of tokens (unpadded)
ubatch_slices: Optional ubatch slicing info for DBO
Returns:
A tuple of:
- slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata
- slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext
"""
if not (
hasattr(self, "kv_cache_config")
and self.kv_cache_config is not None
and len(self.kv_cache_config.kv_cache_groups) > 0
):
return None, None
def _get_slot_mapping(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
kv_cache_gid
].kv_cache_spec
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
slot_mapping = torch.zeros(
(num_tokens_padded,),
dtype=torch.int64,
device=self.device,
)
else:
blk_table = self.input_batch.block_table[kv_cache_gid]
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1)
return slot_mapping
slot_mappings_by_gid = {
gid: _get_slot_mapping(gid)
for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups)
}
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups):
slot_mapping = slot_mappings_by_gid[gid]
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
if ubatch_slices is not None:
result: list[dict[str, torch.Tensor]] = []
for ubatch in ubatch_slices:
sliced_mappings: dict[str, torch.Tensor] = {}
for layer_name, slot_mapping in slot_mappings_by_layer.items():
sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
result.append(sliced_mappings)
return slot_mappings_by_gid, result
return slot_mappings_by_gid, slot_mappings_by_layer
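
The ubatch branch above simply gives each micro-batch a view of every layer's slot mapping restricted to its token slice. A toy equivalent with hypothetical names:

import torch

by_layer = {"model.layers.0.attn": torch.arange(8)}
token_slices = [slice(0, 4), slice(4, 8)]  # one slice per ubatch

per_ubatch = [
    {name: sm[s] for name, sm in by_layer.items()}
    for s in token_slices
]
# per_ubatch[1]["model.layers.0.attn"] -> tensor([4, 5, 6, 7])
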
@torch.inference_mode()
def execute_model(
self,
@@ -3248,6 +3318,17 @@ class GPUModelRunner(
ubatch_slices_padded,
)
# True if any attention backend handles KV cache update separately
# from forward() (i.e., forward_includes_kv_cache_update=False). When true,
# slot_mappings must use padded dimensions to match the key/value tensors.
has_separate_kv_update = not all(
all(
g.backend.forward_includes_kv_cache_update
for g in self.attn_groups[id]
)
for id, spec in enumerate(self.kv_cache_config.kv_cache_groups)
if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec)
)
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
if self.cache_config.mamba_cache_mode == "align":
@@ -3265,6 +3346,17 @@ class GPUModelRunner(
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
num_tokens_padded=num_tokens_padded
if pad_attn or has_separate_kv_update
else num_tokens_unpadded,
num_reqs_padded=(
num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs
),
num_tokens_unpadded=num_tokens_unpadded,
ubatch_slices=ubatch_slices_padded,
)
attn_metadata, spec_decode_common_attn_metadata = (
self._build_attention_metadata(
num_tokens=num_tokens_unpadded,
@@ -3277,6 +3369,7 @@ class GPUModelRunner(
use_spec_decode=use_spec_decode,
num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
slot_mappings=slot_mappings_by_group,
)
)
@@ -3317,6 +3410,7 @@ class GPUModelRunner(
cudagraph_runtime_mode=cudagraph_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
skip_compiled=has_encoder_input,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
@@ -3399,6 +3493,7 @@ class GPUModelRunner(
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
slot_mappings,
)
self.kv_connector_output = kv_connector_output
return None
@@ -3435,6 +3530,7 @@ class GPUModelRunner(
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
slot_mappings,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
@@ -3468,6 +3564,7 @@ class GPUModelRunner(
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
slot_mappings,
)
self._copy_draft_token_ids_to_cpu(scheduler_output)
@@ -3676,6 +3773,7 @@ class GPUModelRunner(
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
spec_config = self.speculative_config
@@ -3687,11 +3785,14 @@ class GPUModelRunner(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
)
elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)
draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
draft_token_ids = self.drafter.propose(
self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
)
elif spec_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)
@@ -3716,6 +3817,7 @@ class GPUModelRunner(
draft_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
slot_mappings=slot_mappings,
)
elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
@@ -3826,6 +3928,7 @@ class GPUModelRunner(
common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings,
)
return draft_token_ids
@@ -4406,6 +4509,13 @@ class GPUModelRunner(
attn_metadata: PerLayerAttnMetadata | None = None
slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
num_tokens_padded=num_tokens,
num_reqs_padded=num_reqs_padded,
num_tokens_unpadded=num_tokens_unpadded,
ubatch_slices=ubatch_slices_padded,
)
# If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
@@ -4431,6 +4541,7 @@ class GPUModelRunner(
max_query_len=max_query_len,
ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
for_cudagraph_capture=is_graph_capturing,
slot_mappings=slot_mappings_by_group,
)
with self.maybe_dummy_run_with_lora(
@@ -4499,6 +4610,7 @@ class GPUModelRunner(
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
),
):
outputs = self.model(
@@ -4545,6 +4657,7 @@ class GPUModelRunner(
num_tokens,
use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing,
slot_mappings=slot_mappings,
)
# We register layerwise NVTX hooks here after the first dynamo tracing is

View File

@@ -295,6 +295,7 @@ class UBatchWrapper:
self,
ubatch_slices,
attn_metadata,
slot_mapping,
input_ids,
positions,
inputs_embeds,
@@ -306,6 +307,9 @@ class UBatchWrapper:
) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
# slot_mapping can be None, an empty dict (from create_forward_context
# converting None to {}), or a list of dicts (one per ubatch)
has_slot_mapping = slot_mapping and isinstance(slot_mapping, list)
for i, ubatch_slice in enumerate(ubatch_slices):
forward_contexts.append(
create_forward_context(
@@ -314,6 +318,7 @@ class UBatchWrapper:
dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping[i] if has_slot_mapping else None,
)
)
@@ -406,6 +411,7 @@ class UBatchWrapper:
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
slot_mapping = forward_context.slot_mapping
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
@@ -440,6 +446,7 @@ class UBatchWrapper:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
@@ -462,6 +469,7 @@ class UBatchWrapper:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,