[Performance] Split FlashAttn attention and cache update (#25954)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Luka Govedič <luka.govedic@gmail.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
@@ -390,29 +390,43 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata, output=output
|
||||
kv_cache_dummy_dep = None
|
||||
if not self.attn_backend.forward_includes_kv_cache_update:
|
||||
kv_cache_dummy_dep = unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
else:
|
||||
kv_cache_dummy_dep = None
|
||||
if not self.attn_backend.forward_includes_kv_cache_update and (
|
||||
# torch can only dispatch custom op if a tensor is passed
|
||||
key is not None or value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query, key, value, output, self.layer_name
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
assert self.attn_backend.forward_includes_kv_cache_update, (
|
||||
"Split KV cache update not supported when output tensor not provided."
|
||||
)
|
||||
if self.use_direct_call:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata
|
||||
)
|
||||
return unified_attention(query, key, value, self.layer_name)
|
||||
else:
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name
|
||||
@@ -802,6 +816,55 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def unified_kv_cache_update(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a dummy that is passed to unified_attention to signal a side effect and
|
||||
the data dependency between them to ensure torch.compile preserves ordering.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
)
|
||||
layer_slot_mapping = slot_mapping.get(layer_name)
|
||||
if layer_slot_mapping is not None:
|
||||
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
|
||||
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
|
||||
)
|
||||
attn_layer.impl.do_kv_cache_update(
|
||||
attn_layer,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
layer_slot_mapping,
|
||||
)
|
||||
|
||||
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
def unified_kv_cache_update_fake(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(0, device=key.device, dtype=key.dtype)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_kv_cache_update",
|
||||
op_func=unified_kv_cache_update,
|
||||
fake_impl=unified_kv_cache_update_fake,
|
||||
mutates_args=[],
|
||||
)
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def unified_attention_with_output(
|
||||
query: torch.Tensor,
|
||||
@@ -811,7 +874,12 @@ def unified_attention_with_output(
|
||||
layer_name: str,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
|
||||
# that ensures torch.compile preserves ordering between KV cache update and
|
||||
# attention forward.
|
||||
del kv_cache_dummy_dep
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
|
||||
self.impl.forward(
|
||||
@@ -835,6 +903,7 @@ def unified_attention_with_output_fake(
|
||||
layer_name: str,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@@ -189,6 +189,7 @@ class ForwardContext:
|
||||
# copy from vllm_config.compilation_config.static_forward_context
|
||||
no_compile_layers: dict[str, Any]
|
||||
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
|
||||
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
|
||||
"""
|
||||
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
|
||||
attention layer to its attention metadata
|
||||
@@ -266,6 +267,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
slot_mapping: dict[str, torch.Tensor] | None = None,
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
skip_compiled: bool = False,
|
||||
):
|
||||
@@ -282,6 +284,7 @@ def create_forward_context(
|
||||
remaining_moe_layers=remaining_moe_layers,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping or {},
|
||||
dp_metadata=dp_metadata,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
@@ -316,6 +319,7 @@ def set_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
|
||||
skip_compiled: bool = False,
|
||||
):
|
||||
"""A context manager that stores the current forward context,
|
||||
@@ -374,6 +378,7 @@ def set_forward_context(
|
||||
cudagraph_runtime_mode,
|
||||
batch_descriptor,
|
||||
ubatch_slices,
|
||||
slot_mapping,
|
||||
additional_kwargs,
|
||||
skip_compiled,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
|
||||
@@ -72,6 +72,7 @@ def create_cross_attention_backend(
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "CrossAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
underlying_impl = underlying_attn_backend.get_impl_cls()
|
||||
|
||||
class CrossAttentionBuilder(underlying_builder): # type: ignore
|
||||
def build(
|
||||
@@ -106,18 +107,60 @@ def create_cross_attention_backend(
|
||||
)
|
||||
|
||||
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
|
||||
new_metadata.slot_mapping = _get_cross_slot_mapping(
|
||||
slot_mapping = _get_cross_slot_mapping(
|
||||
new_metadata.encoder_seq_lens_cpu,
|
||||
new_metadata.block_table_tensor,
|
||||
self.kv_cache_spec,
|
||||
self.device,
|
||||
)
|
||||
return super().build(common_prefix_len, new_metadata, fast_build)
|
||||
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
return attn_metadata
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
|
||||
# `CrossAttentionBuilder` instead of the one computed by `BlockTable`
|
||||
# (gpu_model_runner)
|
||||
class CrossAttentionImpl(underlying_impl): # type: ignore[valid-type,misc]
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
not underlying_attn_backend.forward_includes_kv_cache_update
|
||||
and attn_metadata is not None
|
||||
):
|
||||
self.do_kv_cache_update(
|
||||
layer, key, value, kv_cache, attn_metadata.slot_mapping
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output,
|
||||
output_scale,
|
||||
output_block_scale,
|
||||
)
|
||||
|
||||
attn_backend = subclass_attention_backend_with_overrides(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=CrossAttentionBuilder,
|
||||
overrides={
|
||||
"get_builder_cls": lambda: CrossAttentionBuilder,
|
||||
"get_impl_cls": lambda: CrossAttentionImpl,
|
||||
"forward_includes_kv_cache_update": True,
|
||||
},
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
@@ -613,8 +613,9 @@ def weak_ref_tensor(tensor: Any) -> Any:
|
||||
Create a weak reference to a tensor.
|
||||
The new tensor will share the same data as the original tensor,
|
||||
but will not keep the original tensor alive.
|
||||
This ignores 0-size tensors as those don't allocate any memory.
|
||||
"""
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
if isinstance(tensor, torch.Tensor) and tensor.numel() > 0:
|
||||
return torch.ops._C.weak_ref_tensor(tensor)
|
||||
else:
|
||||
return tensor
|
||||
|
||||
@@ -53,6 +53,9 @@ class AttentionBackend(ABC):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
|
||||
|
||||
# Does attention's forward() include kv cache update?
|
||||
forward_includes_kv_cache_update: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(1)]
|
||||
|
||||
@@ -79,6 +79,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
return [16, 32, 64]
|
||||
return [MultipleOf(16)]
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
@@ -652,32 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
# queries are quantized in the attention layer
|
||||
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
@@ -774,6 +750,49 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is not None
|
||||
or key is None
|
||||
or value is None
|
||||
):
|
||||
return
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def _forward_with_dcp(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
|
||||
@@ -159,6 +159,10 @@ class SpecDecodeBaseProposer:
|
||||
with_numpy=True,
|
||||
)
|
||||
|
||||
self._slot_mapping_buffer = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
# Determine allowed attention backends once during initialization.
|
||||
self.allowed_attn_types: tuple | None = None
|
||||
if current_platform.is_rocm():
|
||||
@@ -230,6 +234,24 @@ class SpecDecodeBaseProposer:
|
||||
positions = positions[0]
|
||||
self.positions[:num_tokens] = positions
|
||||
|
||||
def _get_slot_mapping(
|
||||
self,
|
||||
num_tokens: int,
|
||||
slot_mapping: torch.Tensor | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Return slot_mapping dict for EAGLE layers.
|
||||
|
||||
If slot_mapping is provided, copies it into the buffer first.
|
||||
"""
|
||||
if slot_mapping is not None:
|
||||
num_actual = slot_mapping.shape[0]
|
||||
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
|
||||
if num_tokens > num_actual:
|
||||
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
|
||||
|
||||
view = self._slot_mapping_buffer[:num_tokens]
|
||||
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
"""Initialize cudagraph dispatcher keys for eagle.
|
||||
|
||||
@@ -262,6 +284,9 @@ class SpecDecodeBaseProposer:
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = common_attn_metadata.batch_size()
|
||||
|
||||
@@ -358,6 +383,9 @@ class SpecDecodeBaseProposer:
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
@@ -396,6 +424,7 @@ class SpecDecodeBaseProposer:
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
@@ -553,6 +582,9 @@ class SpecDecodeBaseProposer:
|
||||
num_tokens=input_batch_size,
|
||||
num_tokens_across_dp=batch_size_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
input_batch_size, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
@@ -760,6 +792,9 @@ class SpecDecodeBaseProposer:
|
||||
# [num_tokens, hidden_size]
|
||||
hidden_states: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
tree_attn_metadata_builder = self.runner.attn_groups[0][
|
||||
0
|
||||
@@ -881,6 +916,9 @@ class SpecDecodeBaseProposer:
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
@@ -1212,6 +1250,7 @@ class SpecDecodeBaseProposer:
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
is_graph_capturing: bool = False,
|
||||
slot_mappings: dict[str, torch.Tensor] | None = None,
|
||||
) -> None:
|
||||
# FIXME: when using tree-based specdec, adjust number of forward-passes
|
||||
# according to the depth of the tree.
|
||||
@@ -1233,12 +1272,23 @@ class SpecDecodeBaseProposer:
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
# Make sure to use EAGLE's own buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
slot_mapping_dict = slot_mappings or {}
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=slot_mapping_dict,
|
||||
):
|
||||
if self.supports_mm_inputs:
|
||||
input_ids = None
|
||||
|
||||
@@ -38,6 +38,9 @@ class MedusaProposer:
|
||||
self,
|
||||
target_hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None, # unused
|
||||
) -> torch.Tensor:
|
||||
# Generate blocks and compute logits
|
||||
blocks = self.model(target_hidden_states)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numba import get_num_threads, jit, njit, prange, set_num_threads
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
@@ -132,6 +133,9 @@ class NgramProposer:
|
||||
sampled_token_ids: list[list[int]],
|
||||
num_tokens_no_spec: np.ndarray,
|
||||
token_ids_cpu: np.ndarray,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None, # unused
|
||||
) -> list[list[int]]:
|
||||
# find which requests need ngram proposals
|
||||
valid_ngram_requests = []
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
@@ -33,6 +35,9 @@ class SuffixDecodingProposer:
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
sampled_token_ids: list[list[int]],
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None, # unused
|
||||
) -> list[list[int]]:
|
||||
"""
|
||||
Propose speculative tokens for each request in the input batch. Suffix Decoding
|
||||
|
||||
@@ -140,6 +140,18 @@ def init_kv_cache(
|
||||
return kv_caches
|
||||
|
||||
|
||||
def build_slot_mappings_by_layer(
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
|
||||
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
slot_mapping = slot_mappings[i]
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
slot_mappings_by_layer[layer_name] = slot_mapping
|
||||
return slot_mappings_by_layer
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
|
||||
@@ -14,7 +14,10 @@ from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
@@ -88,7 +91,7 @@ class CudaGraphManager:
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:num_tokens]
|
||||
attn_metadata = prepare_inputs_to_capture(
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
input_buffers,
|
||||
@@ -98,6 +101,9 @@ class CudaGraphManager:
|
||||
kv_cache_config,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, kv_cache_config
|
||||
)
|
||||
|
||||
# Warm up.
|
||||
with set_forward_context(
|
||||
@@ -106,6 +112,7 @@ class CudaGraphManager:
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings_by_layer,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
@@ -125,6 +132,7 @@ class CudaGraphManager:
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings_by_layer,
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
@@ -244,7 +252,7 @@ def prepare_inputs_to_capture(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
) -> tuple[dict[str, Any], torch.Tensor]:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
@@ -274,4 +282,4 @@ def prepare_inputs_to_capture(
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
return attn_metadata
|
||||
return attn_metadata, slot_mappings
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
init_kv_cache,
|
||||
@@ -881,6 +882,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.uses_mrope:
|
||||
assert input_batch.mrope_positions is not None
|
||||
positions = input_batch.mrope_positions
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
input_batch.idx_mapping,
|
||||
input_batch.query_start_loc,
|
||||
input_batch.positions[: input_batch.num_tokens],
|
||||
)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -888,6 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO(woosuk): Support piecewise CUDA graph.
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings_by_layer,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
hidden_states = self.model(
|
||||
|
||||
@@ -314,6 +314,7 @@ class ExecuteModelState(NamedTuple):
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
ec_connector_output: ECConnectorOutput | None
|
||||
cudagraph_stats: CUDAGraphStat | None
|
||||
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None
|
||||
|
||||
|
||||
class GPUModelRunner(
|
||||
@@ -1595,6 +1596,7 @@ class GPUModelRunner(
|
||||
for_cudagraph_capture: bool = False,
|
||||
num_scheduled_tokens: dict[str, int] | None = None,
|
||||
cascade_attn_prefix_lens: list[list[int]] | None = None,
|
||||
slot_mappings: dict[int, torch.Tensor] | None = None,
|
||||
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
|
||||
"""
|
||||
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
|
||||
@@ -1628,7 +1630,7 @@ class GPUModelRunner(
|
||||
|
||||
kv_cache_groups = self.kv_cache_config.kv_cache_groups
|
||||
|
||||
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
|
||||
def _get_block_table(kv_cache_gid: int):
|
||||
assert num_reqs_padded is not None and num_tokens_padded is not None
|
||||
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
|
||||
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
@@ -1637,24 +1639,19 @@ class GPUModelRunner(
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
slot_mapping = torch.zeros(
|
||||
(num_tokens_padded,),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
||||
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
|
||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
|
||||
return blk_table_tensor
|
||||
|
||||
return blk_table_tensor, slot_mapping
|
||||
assert slot_mappings is not None
|
||||
block_table_gid_0 = _get_block_table(0)
|
||||
slot_mapping_gid_0 = slot_mappings[0]
|
||||
|
||||
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
|
||||
cm_base = CommonAttentionMetadata(
|
||||
@@ -1779,9 +1776,8 @@ class GPUModelRunner(
|
||||
num_reqs_padded,
|
||||
)
|
||||
if kv_cache_gid > 0:
|
||||
cm.block_table_tensor, cm.slot_mapping = (
|
||||
_get_block_table_and_slot_mapping(kv_cache_gid)
|
||||
)
|
||||
cm.block_table_tensor = _get_block_table(kv_cache_gid)
|
||||
cm.slot_mapping = slot_mappings[kv_cache_gid]
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
if isinstance(self.drafter, EagleProposer):
|
||||
@@ -3119,6 +3115,80 @@ class GPUModelRunner(
|
||||
pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
|
||||
self.layerwise_nvtx_hooks_registered = True
|
||||
|
||||
def _get_slot_mappings(
|
||||
self,
|
||||
num_tokens_padded: int,
|
||||
num_reqs_padded: int,
|
||||
num_tokens_unpadded: int,
|
||||
ubatch_slices: "UBatchSlices | None" = None,
|
||||
) -> tuple[
|
||||
dict[int, torch.Tensor] | None,
|
||||
dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
|
||||
]:
|
||||
"""
|
||||
Build slot mappings in both formats needed by the system.
|
||||
|
||||
Args:
|
||||
num_tokens_padded: Total number of tokens (padded)
|
||||
num_reqs_padded: Total number of requests (padded)
|
||||
num_tokens_unpadded: Actual number of tokens (unpadded)
|
||||
ubatch_slices: Optional ubatch slicing info for DBO
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
- slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata
|
||||
- slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext
|
||||
"""
|
||||
if not (
|
||||
hasattr(self, "kv_cache_config")
|
||||
and self.kv_cache_config is not None
|
||||
and len(self.kv_cache_config.kv_cache_groups) > 0
|
||||
):
|
||||
return None, None
|
||||
|
||||
def _get_slot_mapping(kv_cache_gid: int):
|
||||
assert num_reqs_padded is not None and num_tokens_padded is not None
|
||||
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
kv_cache_gid
|
||||
].kv_cache_spec
|
||||
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
slot_mapping = torch.zeros(
|
||||
(num_tokens_padded,),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
||||
slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1)
|
||||
|
||||
return slot_mapping
|
||||
|
||||
slot_mappings_by_gid = {
|
||||
gid: _get_slot_mapping(gid)
|
||||
for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups)
|
||||
}
|
||||
|
||||
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
|
||||
for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
slot_mapping = slot_mappings_by_gid[gid]
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
slot_mappings_by_layer[layer_name] = slot_mapping
|
||||
|
||||
if ubatch_slices is not None:
|
||||
result: list[dict[str, torch.Tensor]] = []
|
||||
for ubatch in ubatch_slices:
|
||||
sliced_mappings: dict[str, torch.Tensor] = {}
|
||||
for layer_name, slot_mapping in slot_mappings_by_layer.items():
|
||||
sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
|
||||
result.append(sliced_mappings)
|
||||
return slot_mappings_by_gid, result
|
||||
|
||||
return slot_mappings_by_gid, slot_mappings_by_layer
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -3248,6 +3318,17 @@ class GPUModelRunner(
|
||||
ubatch_slices_padded,
|
||||
)
|
||||
|
||||
# True if any attention backend handles KV cache update separately
|
||||
# from forward() (i.e., forward_includes_kv_cache_update=False). When true,
|
||||
# slot_mappings must use padded dimensions to match the key/value tensors.
|
||||
has_separate_kv_update = not all(
|
||||
all(
|
||||
g.backend.forward_includes_kv_cache_update
|
||||
for g in self.attn_groups[id]
|
||||
)
|
||||
for id, spec in enumerate(self.kv_cache_config.kv_cache_groups)
|
||||
if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec)
|
||||
)
|
||||
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
if self.cache_config.mamba_cache_mode == "align":
|
||||
@@ -3265,6 +3346,17 @@ class GPUModelRunner(
|
||||
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
|
||||
|
||||
slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
|
||||
num_tokens_padded=num_tokens_padded
|
||||
if pad_attn or has_separate_kv_update
|
||||
else num_tokens_unpadded,
|
||||
num_reqs_padded=(
|
||||
num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs
|
||||
),
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
ubatch_slices=ubatch_slices_padded,
|
||||
)
|
||||
|
||||
attn_metadata, spec_decode_common_attn_metadata = (
|
||||
self._build_attention_metadata(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
@@ -3277,6 +3369,7 @@ class GPUModelRunner(
|
||||
use_spec_decode=use_spec_decode,
|
||||
num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
|
||||
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
|
||||
slot_mappings=slot_mappings_by_group,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -3317,6 +3410,7 @@ class GPUModelRunner(
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
batch_descriptor=batch_desc,
|
||||
ubatch_slices=ubatch_slices_padded,
|
||||
slot_mapping=slot_mappings,
|
||||
skip_compiled=has_encoder_input,
|
||||
),
|
||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||
@@ -3399,6 +3493,7 @@ class GPUModelRunner(
|
||||
aux_hidden_states,
|
||||
ec_connector_output,
|
||||
cudagraph_stats,
|
||||
slot_mappings,
|
||||
)
|
||||
self.kv_connector_output = kv_connector_output
|
||||
return None
|
||||
@@ -3435,6 +3530,7 @@ class GPUModelRunner(
|
||||
aux_hidden_states,
|
||||
ec_connector_output,
|
||||
cudagraph_stats,
|
||||
slot_mappings,
|
||||
) = self.execute_model_state
|
||||
# Clear ephemeral state.
|
||||
self.execute_model_state = None
|
||||
@@ -3468,6 +3564,7 @@ class GPUModelRunner(
|
||||
aux_hidden_states,
|
||||
spec_decode_metadata,
|
||||
spec_decode_common_attn_metadata,
|
||||
slot_mappings,
|
||||
)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output)
|
||||
|
||||
@@ -3676,6 +3773,7 @@ class GPUModelRunner(
|
||||
aux_hidden_states: list[torch.Tensor] | None,
|
||||
spec_decode_metadata: SpecDecodeMetadata | None,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
|
||||
) -> list[list[int]] | torch.Tensor:
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
spec_config = self.speculative_config
|
||||
@@ -3687,11 +3785,14 @@ class GPUModelRunner(
|
||||
sampled_token_ids,
|
||||
self.input_batch.num_tokens_no_spec,
|
||||
self.input_batch.token_ids_cpu,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
elif spec_config.method == "suffix":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
assert isinstance(self.drafter, SuffixDecodingProposer)
|
||||
draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
|
||||
draft_token_ids = self.drafter.propose(
|
||||
self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
|
||||
)
|
||||
elif spec_config.method == "medusa":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
assert isinstance(self.drafter, MedusaProposer)
|
||||
@@ -3716,6 +3817,7 @@ class GPUModelRunner(
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_hidden_states=hidden_states,
|
||||
sampling_metadata=sampling_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
elif spec_config.use_eagle() or spec_config.uses_draft_model():
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
@@ -3826,6 +3928,7 @@ class GPUModelRunner(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
mm_embed_inputs=mm_embed_inputs,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
@@ -4406,6 +4509,13 @@ class GPUModelRunner(
|
||||
|
||||
attn_metadata: PerLayerAttnMetadata | None = None
|
||||
|
||||
slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
|
||||
num_tokens_padded=num_tokens,
|
||||
num_reqs_padded=num_reqs_padded,
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
ubatch_slices=ubatch_slices_padded,
|
||||
)
|
||||
|
||||
# If force_attention is True, we always capture attention. Otherwise,
|
||||
# it only happens for cudagraph_runtime_mode=FULL.
|
||||
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
@@ -4431,6 +4541,7 @@ class GPUModelRunner(
|
||||
max_query_len=max_query_len,
|
||||
ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
|
||||
for_cudagraph_capture=is_graph_capturing,
|
||||
slot_mappings=slot_mappings_by_group,
|
||||
)
|
||||
|
||||
with self.maybe_dummy_run_with_lora(
|
||||
@@ -4499,6 +4610,7 @@ class GPUModelRunner(
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_desc,
|
||||
ubatch_slices=ubatch_slices_padded,
|
||||
slot_mapping=slot_mappings,
|
||||
),
|
||||
):
|
||||
outputs = self.model(
|
||||
@@ -4545,6 +4657,7 @@ class GPUModelRunner(
|
||||
num_tokens,
|
||||
use_cudagraphs=use_cudagraphs,
|
||||
is_graph_capturing=is_graph_capturing,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
|
||||
# We register layerwise NVTX hooks here after the first dynamo tracing is
|
||||
|
||||
@@ -295,6 +295,7 @@ class UBatchWrapper:
|
||||
self,
|
||||
ubatch_slices,
|
||||
attn_metadata,
|
||||
slot_mapping,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
@@ -306,6 +307,9 @@ class UBatchWrapper:
|
||||
) -> list[UbatchMetadata]:
|
||||
# Create one forward context per ubatch
|
||||
forward_contexts = []
|
||||
# slot_mapping can be None, an empty dict (from create_forward_context
|
||||
# converting None to {}), or a list of dicts (one per ubatch)
|
||||
has_slot_mapping = slot_mapping and isinstance(slot_mapping, list)
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
forward_contexts.append(
|
||||
create_forward_context(
|
||||
@@ -314,6 +318,7 @@ class UBatchWrapper:
|
||||
dp_metadata=dp_metadata[i],
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=slot_mapping[i] if has_slot_mapping else None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -406,6 +411,7 @@ class UBatchWrapper:
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
num_tokens = (
|
||||
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
|
||||
) * 2
|
||||
@@ -440,6 +446,7 @@ class UBatchWrapper:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
@@ -462,6 +469,7 @@ class UBatchWrapper:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
|
||||
Reference in New Issue
Block a user