[Frontend] support matryoshka representation / support embedding API dimensions (#16331)

This commit is contained in:
wang.yuqi
2025-04-12 14:23:10 +08:00
committed by GitHub
parent e92d7085bf
commit fbf722c6e6
11 changed files with 253 additions and 22 deletions

View File

@@ -583,6 +583,15 @@ class ModelConfig:
if getattr(user_config, k) is None:
setattr(user_config, k, v)
if self.is_matryoshka:
if user_config.normalize is None:
user_config.normalize = True
elif not user_config.normalize:
raise ValueError(
"`normalize` must be enabled (set to True) "
"for models that are compatible with "
"Matryoshka Representation.")
return user_config
return None

View File

@@ -921,6 +921,11 @@ class LLM:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
self._validate_and_add_requests(
prompts=parsed_prompts,
@@ -939,6 +944,8 @@ class LLM:
/,
*,
use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]:
@@ -953,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
@@ -968,6 +977,7 @@ class LLM:
items = self.encode(prompts,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

View File

@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1068,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

View File

@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model)
request_id = f"embd-{self._base_request_id(raw_request)}"
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size.")
pooling_params = request.to_pooling_params()
try:
pooling_params.verify(self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
try:
(
lora_request,
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"

View File

@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data)
pooled_data = self.head(pooled_data, pooling_metadata)
pooled_outputs = [self.build_output(data) for data in pooled_data]
return PoolerOutput(outputs=pooled_outputs)
@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
self.normalize = normalize
self.softmax = softmax
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
pooled_data = [
vecs if d is None else vecs[..., :d]
for vecs, d in zip(pooled_data, dimensions_list)
]
if self.normalize:
if isinstance(pooled_data, list):
pooled_data = [
F.normalize(data, p=2, dim=1) for data in pooled_data
F.normalize(data, p=2, dim=-1) for data in pooled_data
]
else:
pooled_data = F.normalize(pooled_data, p=2, dim=1)
pooled_data = F.normalize(pooled_data, p=2, dim=-1)
if self.softmax:
if isinstance(pooled_data, list):

View File

@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import msgspec
if TYPE_CHECKING:
from vllm.config import ModelConfig
class PoolingParams(
msgspec.Struct,
@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder.
Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions: Optional[int] = None
additional_data: Optional[Any] = None
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
def verify(self, model_config: "ModelConfig") -> None:
if self.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f'support matryoshka representation, '
f'changing output dimensions will lead to poor results.')
if self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str:
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})")