[V1] Support LLM.apply_model (#18465)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
from vllm.worker.model_runner_base import InputProcessingError
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
@@ -1817,13 +1819,16 @@ class LLMEngine:
|
||||
return sampling_params
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
method: Union[str, Callable[[WorkerBase], _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.model_executor.collective_rpc(method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
return self.collective_rpc("apply_model", args=(func, ))
|
||||
|
||||
|
||||
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
|
||||
Reference in New Issue
Block a user