[core] allow callable in collective_rpc (#12151)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-17 20:47:01 +08:00
committed by GitHub
parent d4e6194570
commit 87a0c076af
13 changed files with 147 additions and 50 deletions

View File

@@ -5,10 +5,10 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
from typing import Set, Tuple, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ class LLMEngine:
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
See LLM.collective_rpc for more details.
"""
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()