[Frontend] support matryoshka representation / support embedding API dimensions (#16331)
This commit is contained in:
@@ -921,6 +921,11 @@ class LLM:
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
elif isinstance(pooling_params, PoolingParams):
|
||||
pooling_params.verify(self.llm_engine.model_config)
|
||||
else:
|
||||
for pooling_param in pooling_params:
|
||||
pooling_param.verify(self.llm_engine.model_config)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
@@ -939,6 +944,8 @@ class LLM:
|
||||
/,
|
||||
*,
|
||||
use_tqdm: bool = True,
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> list[EmbeddingRequestOutput]:
|
||||
@@ -953,6 +960,8 @@ class LLM:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each prompts.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
@@ -968,6 +977,7 @@ class LLM:
|
||||
|
||||
items = self.encode(prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
|
||||
@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
# doc: end-embedding-extra-params
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
additional_data=self.additional_data)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
@@ -1068,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
additional_data=self.additional_data)
|
||||
|
||||
|
||||
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||
|
||||
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = self._get_model_name(request.model)
|
||||
request_id = f"embd-{self._base_request_id(raw_request)}"
|
||||
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify(self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user