[Frontend] Support using chat template as custom score template for reranking models (#30550)
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
@@ -1280,6 +1280,7 @@ class LLM:
         pooling_params: PoolingParams | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
+        score_template: str | None = None,
     ) -> list[ScoringRequestOutput]:
         model_config = self.model_config
 
@@ -1313,6 +1314,7 @@ class LLM:
                 data_2=d,
                 tokenizer=tokenizer,
                 tokenization_kwargs=tokenization_kwargs,
+                score_template=score_template,
             )
 
             if token_type_ids := engine_prompt.pop("token_type_ids", None):
@@ -1347,6 +1349,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        chat_template: str | None = None,
     ) -> list[ScoringRequestOutput]:
         """Generate similarity scores for all pairs `<text,text_pair>` or
         `<multi-modal data, multi-modal data pair>`.
@@ -1379,6 +1382,8 @@ class LLM:
             lora_request: LoRA request to use for generation, if any.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
+            chat_template: The chat template to use for the scoring. If None, we
+                use the model's default chat template.
         Returns:
             A list of `ScoringRequestOutput` objects containing the
             generated scores in the same order as the input prompts.
@@ -1406,6 +1411,11 @@ class LLM:
         ):
             raise ValueError("Score API is only enabled for num_labels == 1.")
 
+        if not model_config.is_cross_encoder and chat_template is not None:
+            raise ValueError(
+                "chat_template is only supported for cross-encoder models."
+            )
+
         # the tokenizer for models such as
         # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
         # lists of tokens to the `text` and `text_pair` kwargs
@@ -1475,6 +1485,7 @@ class LLM:
                 use_tqdm,
                 pooling_params,
                 lora_request,
+                score_template=chat_template,
             )
         else:
             return self._embedding_score(
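
For context, a minimal usage sketch of the new keyword follows. Only the `chat_template` parameter of `LLM.score()` (forwarded internally as `score_template` and accepted only for cross-encoder models) comes from this diff; the model name and the template string below are illustrative assumptions.

from vllm import LLM

# Hypothetical example model; any cross-encoder reranking model supported
# by vLLM could stand in here.
llm = LLM(model="BAAI/bge-reranker-v2-m3")

# Assumed template string, for illustration only; the variables a given
# model expects depend on its own chat/score template conventions.
custom_template = "query: {{ query }} document: {{ document }}"

outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Eiffel Tower is in Paris."],
    chat_template=custom_template,  # new keyword introduced by this commit
)
for output in outputs:
    print(output.outputs.score)

If `chat_template` is left as None, the model's default chat template is used, as described in the updated docstring above.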