[platform] Support additional forward context for OOT (#31674)

Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
zzzzwwjj
2026-01-05 18:25:13 +08:00
committed by GitHub
parent b471aad41f
commit caaa482aca
2 changed files with 25 additions and 3 deletions

View File

@@ -4,7 +4,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, NamedTuple
import torch
@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
@@ -206,6 +207,8 @@ class ForwardContext:
ubatch_slices: UBatchSlices | None = None
additional_kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
@@ -236,6 +239,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
additional_kwargs: dict[str, Any] | None = None,
):
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
@@ -245,6 +249,7 @@ def create_forward_context(
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
additional_kwargs=additional_kwargs or {},
)
@@ -310,6 +315,17 @@ def set_forward_context(
if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
additional_kwargs = current_platform.set_additional_forward_context(
attn_metadata=attn_metadata,
vllm_config=vllm_config,
virtual_engine=virtual_engine,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
)
forward_context = create_forward_context(
attn_metadata,
vllm_config,
@@ -318,6 +334,7 @@ def set_forward_context(
cudagraph_runtime_mode,
batch_descriptor,
ubatch_slices,
additional_kwargs,
)
try:
@@ -330,8 +347,6 @@ def set_forward_context(
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
from vllm.platforms import current_platform
synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()

View File

@@ -693,6 +693,13 @@ class Platform:
"""
return max_model_len
@classmethod
def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]:
    """Return additional forward-context kwargs for the current platform.

    Hook for out-of-tree platforms to inject platform-specific entries
    into ``ForwardContext.additional_kwargs``. The base implementation
    accepts and ignores all arguments and returns an empty dict, so
    platforms that need no extra context inherit a no-op.

    Returns:
        A dict of extra keyword arguments (empty by default).
    """
    return {}
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED