Support token_type_ids in V1 with less code changes (#21985)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
committed by
GitHub
parent
9c97a1c349
commit
39052dbca8
@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages,
|
||||
resolve_chat_template_content_format)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
_validate_score_input_lens,
|
||||
compress_token_type_ids,
|
||||
get_score_prompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.utils import (_validate_truncation_size,
|
||||
log_non_default_args)
|
||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||
@@ -1329,6 +1333,7 @@ class LLM:
|
||||
|
||||
model_config = self.llm_engine.model_config
|
||||
pooling_params.verify("score", model_config)
|
||||
pooling_params_list = list[PoolingParams]()
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
|
||||
@@ -1339,38 +1344,31 @@ class LLM:
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
if model_config.is_multimodal_model:
|
||||
for q, d in input_pairs:
|
||||
_, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=q,
|
||||
data_2=d,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
model_config = self.llm_engine.model_config
|
||||
|
||||
parsed_prompts.append(engine_prompt)
|
||||
else:
|
||||
for q, t in input_pairs:
|
||||
if model_config.use_pad_token:
|
||||
# cross_encoder models default to using pad_token.
|
||||
prompt_inputs = tokenizer(
|
||||
text=q, # type: ignore[arg-type]
|
||||
text_pair=t, # type: ignore[arg-type]
|
||||
**tokenization_kwargs)
|
||||
else:
|
||||
# `llm as reranker` models default to not using pad_token.
|
||||
prompt_inputs = tokenizer(
|
||||
text=q + t, # type: ignore[operator]
|
||||
**tokenization_kwargs)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["input_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
parsed_prompts.append(engine_prompt)
|
||||
for q, d in input_pairs:
|
||||
_, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=q,
|
||||
data_2=d,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
|
||||
"token_type_ids", None)):
|
||||
params = pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
params.extra_kwargs = {"compressed_token_type_ids": compressed}
|
||||
pooling_params_list.append(params)
|
||||
else:
|
||||
pooling_params_list.append(pooling_params)
|
||||
|
||||
parsed_prompts.append(engine_prompt)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
params=pooling_params,
|
||||
params=pooling_params_list,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user