[platform] Support additional forward context for OOT (#31674)
Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
@@ -13,6 +13,7 @@ import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
||||
|
||||
@@ -206,6 +207,8 @@ class ForwardContext:
|
||||
|
||||
ubatch_slices: UBatchSlices | None = None
|
||||
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
|
||||
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
|
||||
@@ -236,6 +239,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||
@@ -245,6 +249,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
ubatch_slices=ubatch_slices,
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
@@ -310,6 +315,17 @@ def set_forward_context(
|
||||
if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
|
||||
batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
|
||||
|
||||
additional_kwargs = current_platform.set_additional_forward_context(
|
||||
attn_metadata=attn_metadata,
|
||||
vllm_config=vllm_config,
|
||||
virtual_engine=virtual_engine,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
ubatch_slices=ubatch_slices,
|
||||
)
|
||||
|
||||
forward_context = create_forward_context(
|
||||
attn_metadata,
|
||||
vllm_config,
|
||||
@@ -318,6 +334,7 @@ def set_forward_context(
|
||||
cudagraph_runtime_mode,
|
||||
batch_descriptor,
|
||||
ubatch_slices,
|
||||
additional_kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -330,8 +347,6 @@ def set_forward_context(
|
||||
# we use synchronous scheduling right now,
|
||||
# adding a sync point here should not affect
|
||||
# scheduling of the next batch
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
synchronize = current_platform.synchronize
|
||||
if synchronize is not None:
|
||||
synchronize()
|
||||
|
||||
@@ -693,6 +693,13 @@ class Platform:
|
||||
"""
|
||||
return max_model_len
|
||||
|
||||
@classmethod
|
||||
def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Return additional forward context entries for the current platform, if needed. The base implementation returns an empty dict.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
Reference in New Issue
Block a user