Fix wrong truncate_prompt_tokens type hint (#22761)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -51,7 +51,7 @@ from vllm.tasks import PoolingTask
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, Device, is_list_of
|
||||
from vllm.utils import Counter, Device, as_iter, is_list_of
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -364,14 +364,6 @@ class LLM:
|
||||
# Use default sampling params.
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
truncate_prompt_tokens = None
|
||||
if isinstance(sampling_params, SamplingParams):
|
||||
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
||||
|
||||
_validate_truncation_size(model_config.max_model_len,
|
||||
truncate_prompt_tokens, tokenization_kwargs)
|
||||
|
||||
# Add any modality specific loras to the corresponding prompts
|
||||
lora_request = self._get_modality_specific_lora_reqs(
|
||||
prompts, lora_request)
|
||||
@@ -381,7 +373,6 @@ class LLM:
|
||||
params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@@ -871,6 +862,8 @@ class LLM:
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
pooling_task: Override the pooling task to use.
|
||||
tokenization_kwargs: overrides tokenization_kwargs set in
|
||||
pooling_params
|
||||
|
||||
Returns:
|
||||
A list of `PoolingRequestOutput` objects containing the
|
||||
@@ -916,24 +909,17 @@ class LLM:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
if isinstance(pooling_params, PoolingParams):
|
||||
pooling_params.verify(pooling_task, model_config)
|
||||
else:
|
||||
for pooling_param in pooling_params:
|
||||
pooling_param.verify(pooling_task, model_config)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = dict[str, Any]()
|
||||
_validate_truncation_size(model_config.max_model_len,
|
||||
truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
for param in as_iter(pooling_params):
|
||||
param.verify(pooling_task, model_config)
|
||||
# for backwards compatibility
|
||||
if truncate_prompt_tokens is not None:
|
||||
param.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=prompts,
|
||||
params=pooling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
@@ -1385,7 +1371,6 @@ class LLM:
|
||||
*,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
priority: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
if isinstance(prompts, (str, dict)):
|
||||
@@ -1412,7 +1397,17 @@ class LLM:
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
it = tqdm_func(it, desc="Adding requests")
|
||||
|
||||
model_config = self.llm_engine.model_config
|
||||
|
||||
for i, prompt in enumerate(it):
|
||||
|
||||
param = params[i] if isinstance(params, Sequence) else params
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(model_config.max_model_len,
|
||||
param.truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
self._add_request(
|
||||
prompt,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
|
||||
Reference in New Issue
Block a user