Simplify TokenizerGroup (#16790)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -55,7 +55,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (
|
||||
BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||
TokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import (Counter, Device, deprecate_kwargs,
|
||||
@@ -66,7 +66,6 @@ from vllm.worker.model_runner_base import InputProcessingError
|
||||
logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
|
||||
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
@@ -205,7 +204,7 @@ class LLMEngine:
|
||||
|
||||
return outputs_
|
||||
|
||||
tokenizer: Optional[BaseTokenizerGroup]
|
||||
tokenizer: Optional[TokenizerGroup]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -321,11 +320,6 @@ class LLMEngine:
|
||||
self.parallel_config.disable_custom_all_reduce,
|
||||
})
|
||||
|
||||
if self.tokenizer:
|
||||
# Ping the tokenizer to ensure liveness if it runs in a
|
||||
# different process.
|
||||
self.tokenizer.ping()
|
||||
|
||||
self.cached_scheduler_outputs = [
|
||||
SchedulerOutputState()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
@@ -537,21 +531,12 @@ class LLMEngine:
|
||||
if model_executor := getattr(self, "model_executor", None):
|
||||
model_executor.shutdown()
|
||||
|
||||
def get_tokenizer_group(
|
||||
self,
|
||||
group_type: Type[_G] = BaseTokenizerGroup,
|
||||
) -> _G:
|
||||
tokenizer_group = self.tokenizer
|
||||
|
||||
if tokenizer_group is None:
|
||||
def get_tokenizer_group(self) -> TokenizerGroup:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Unable to get tokenizer because "
|
||||
"skip_tokenizer_init is True")
|
||||
if not isinstance(tokenizer_group, group_type):
|
||||
raise TypeError("Invalid type of tokenizer group. "
|
||||
f"Expected type: {group_type}, but "
|
||||
f"found type: {type(tokenizer_group)}")
|
||||
|
||||
return tokenizer_group
|
||||
return self.tokenizer
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
@@ -559,11 +544,10 @@ class LLMEngine:
|
||||
) -> AnyTokenizer:
|
||||
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
|
||||
|
||||
def _init_tokenizer(self) -> BaseTokenizerGroup:
|
||||
def _init_tokenizer(self) -> TokenizerGroup:
|
||||
return init_tokenizer_from_configs(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
parallel_config=self.parallel_config,
|
||||
lora_config=self.lora_config)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
@@ -1952,8 +1936,6 @@ class LLMEngine:
|
||||
return self.model_executor.is_sleeping
|
||||
|
||||
def check_health(self) -> None:
|
||||
if self.tokenizer:
|
||||
self.tokenizer.check_health()
|
||||
self.model_executor.check_health()
|
||||
|
||||
def is_tracing_enabled(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user