diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 1856d7f70..d1223ad83 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ import time from collections import defaultdict from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, NamedTuple import torch @@ -13,6 +13,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ubatch_utils import UBatchSlices @@ -206,6 +207,8 @@ class ForwardContext: ubatch_slices: UBatchSlices | None = None + additional_kwargs: dict[str, Any] = field(default_factory=dict) + def __post_init__(self): assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" @@ -236,6 +239,7 @@ def create_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + additional_kwargs: dict[str, Any] | None = None, ): return ForwardContext( no_compile_layers=vllm_config.compilation_config.static_forward_context, @@ -245,6 +249,7 @@ def create_forward_context( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, + additional_kwargs=additional_kwargs or {}, ) @@ -310,6 +315,17 @@ def set_forward_context( if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) + additional_kwargs = current_platform.set_additional_forward_context( + attn_metadata=attn_metadata, + vllm_config=vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + 
num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ) + forward_context = create_forward_context( attn_metadata, vllm_config, @@ -318,6 +334,7 @@ def set_forward_context( cudagraph_runtime_mode, batch_descriptor, ubatch_slices, + additional_kwargs, ) try: @@ -330,8 +347,6 @@ def set_forward_context( # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch - from vllm.platforms import current_platform - synchronize = current_platform.synchronize if synchronize is not None: synchronize() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 833b66d5b..6f21f47b9 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -693,6 +693,13 @@ class Platform: """ return max_model_len + @classmethod + def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]: + """ + Return any additional forward-context kwargs for the current platform, if needed. + """ + return {} + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED