[Core] Remove tokenizer group in vLLM (#24078)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
init_tracer)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (
|
||||
TokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
|
||||
@@ -186,7 +185,7 @@ class LLMEngine:
|
||||
|
||||
return outputs_
|
||||
|
||||
tokenizer: Optional[TokenizerGroup]
|
||||
tokenizer: Optional[AnyTokenizer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -233,18 +232,9 @@ class LLMEngine:
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
self.detokenizer = None
|
||||
tokenizer_group = None
|
||||
else:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
|
||||
# Ensure that the function doesn't contain a reference to self,
|
||||
# to avoid engine GC issues
|
||||
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
|
||||
assert tokenizer_group, ("tokenizer_group cannot be None, "
|
||||
"make sure skip_tokenizer_init is False")
|
||||
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
||||
|
||||
self.seq_counter = Counter()
|
||||
self.generation_config_fields = (
|
||||
@@ -389,10 +379,8 @@ class LLMEngine:
|
||||
self.detokenizer,
|
||||
self.scheduler,
|
||||
self.seq_counter,
|
||||
get_tokenizer_for_seq,
|
||||
stop_checker=StopChecker(
|
||||
self.scheduler_config.max_model_len,
|
||||
get_tokenizer_for_seq,
|
||||
self.reasoner if self.decoding_config.reasoning_backend
|
||||
and self.tokenizer else None,
|
||||
),
|
||||
@@ -521,24 +509,15 @@ class LLMEngine:
|
||||
if model_executor := getattr(self, "model_executor", None):
|
||||
model_executor.shutdown()
|
||||
|
||||
def get_tokenizer_group(self) -> TokenizerGroup:
|
||||
def get_tokenizer(self) -> AnyTokenizer:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Unable to get tokenizer because "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
return self.tokenizer
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
|
||||
|
||||
def _init_tokenizer(self) -> TokenizerGroup:
|
||||
return init_tokenizer_from_configs(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
lora_config=self.lora_config)
|
||||
def _init_tokenizer(self) -> AnyTokenizer:
|
||||
return init_tokenizer_from_configs(model_config=self.model_config)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
@@ -574,11 +553,11 @@ class LLMEngine:
|
||||
)
|
||||
return None
|
||||
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seq_id = next(self.seq_counter)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id()
|
||||
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
|
||||
@@ -700,7 +679,6 @@ class LLMEngine:
|
||||
processed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
self._add_processed_request(
|
||||
@@ -1739,29 +1717,22 @@ class LLMEngine:
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
|
||||
metrics.model_execute_time)
|
||||
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest]):
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs):
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
|
||||
|
||||
if encoder_inputs is not None:
|
||||
self._validate_model_input(encoder_inputs,
|
||||
lora_request,
|
||||
prompt_type="encoder")
|
||||
self._validate_model_input(encoder_inputs, prompt_type="encoder")
|
||||
|
||||
self._validate_model_input(decoder_inputs,
|
||||
lora_request,
|
||||
prompt_type="decoder")
|
||||
self._validate_model_input(decoder_inputs, prompt_type="decoder")
|
||||
|
||||
def _validate_model_input(
|
||||
self,
|
||||
prompt_inputs: SingletonInputs,
|
||||
lora_request: Optional[LoRARequest],
|
||||
*,
|
||||
prompt_type: Literal["encoder", "decoder"],
|
||||
):
|
||||
model_config = self.model_config
|
||||
tokenizer = (None if self.tokenizer is None else
|
||||
self.tokenizer.get_lora_tokenizer(lora_request))
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
|
||||
if not prompt_ids:
|
||||
@@ -1822,7 +1793,7 @@ class LLMEngine:
|
||||
logits_processors = []
|
||||
|
||||
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
|
||||
tokenizer = self.get_tokenizer(lora_request=lora_request)
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processors = get_openai_logits_processors(
|
||||
logit_bias=sampling_params.logit_bias,
|
||||
@@ -1835,7 +1806,7 @@ class LLMEngine:
|
||||
sampling_params.allowed_token_ids = None
|
||||
|
||||
if len(sampling_params.bad_words) > 0:
|
||||
tokenizer = self.get_tokenizer(lora_request)
|
||||
tokenizer = self.get_tokenizer()
|
||||
processors = get_bad_words_logits_processors(
|
||||
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
|
||||
logits_processors.extend(processors)
|
||||
|
||||
Reference in New Issue
Block a user