[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-12-01 14:36:51 +08:00
Committed by: GitHub
Commit: d2f058e76c (parent: f877a7d12a)
25 changed files with 166 additions and 123 deletions
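For downstream code that imports the old names, the rename is mechanical. Below is a minimal, hedged sketch of how a consumer might stay compatible across vLLM versions; the try/except fallback is an assumption for illustration and is not part of this commit.

```python
# Hypothetical compatibility shim for code that must run on vLLM releases
# from before and after this rename (illustrative only).
try:
    from vllm.outputs import PoolingRequestOutput  # new name introduced by this commit
except ImportError:
    from vllm.outputs import EmbeddingRequestOutput as PoolingRequestOutput  # older releases
```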

View File

@@ -26,7 +26,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -679,7 +679,7 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@@ -691,7 +691,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
@@ -704,7 +704,7 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@@ -717,7 +717,7 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
@@ -728,7 +728,7 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
...
@overload
@@ -741,7 +741,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
...
@deprecate_kwargs(
@@ -759,7 +759,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
@@ -778,7 +778,7 @@ class LLM:
generation, if any.
Returns:
-A list of ``EmbeddingRequestOutput`` objects containing the
+A list of ``PoolingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
Note:
@@ -821,7 +821,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
-EmbeddingRequestOutput)
+PoolingRequestOutput)
def score(
self,
@@ -832,7 +832,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-) -> List[EmbeddingRequestOutput]:
+) -> List[PoolingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
@@ -854,7 +854,7 @@ class LLM:
generation, if any.
Returns:
-A list of ``EmbeddingRequestOutput`` objects containing the
+A list of ``PoolingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
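The score() docstring above allows 1 -> 1, 1 -> N, and N -> N input shapes. A minimal sketch of that pairing rule follows; the helper name and broadcasting behavior are assumptions for illustration only and are not taken from this diff (the server-side code shown later calls a make_pairs helper for the same purpose).

```python
from typing import List, Tuple, Union

def make_text_pairs(text_1: Union[str, List[str]],
                    text_2: Union[str, List[str]]) -> List[Tuple[str, str]]:
    """Illustrative pairing: 1 -> 1, 1 -> N (broadcast text_1), or N -> N (zip)."""
    left = [text_1] if isinstance(text_1, str) else list(text_1)
    right = [text_2] if isinstance(text_2, str) else list(text_2)
    if len(left) == 1 and len(right) > 1:
        left = left * len(right)  # broadcast the single query against every candidate
    if len(left) != len(right):
        raise ValueError("text_1 and text_2 lengths are not broadcastable")
    return list(zip(left, right))
```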
@@ -943,7 +943,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
-EmbeddingRequestOutput)
+PoolingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()
@@ -1085,7 +1085,7 @@ class LLM:
def _run_engine(
self, *, use_tqdm: bool
-) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1098,7 +1098,7 @@ class LLM:
)
# Run the engine.
-outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
+outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
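Taken together, the hunks above change only the annotated return types of LLM.encode() and LLM.score(); the calling convention is unchanged. A hedged usage sketch follows: the model name, constructor arguments, and attribute access are assumptions for illustration, not details taken from this commit.

```python
from vllm import LLM

# Illustrative only: model name and constructor arguments are assumptions.
llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embedding")
outputs = llm.encode(["Hello, world!", "Pooling models return one vector per prompt."])

for req_out in outputs:
    # After this commit each item is a PoolingRequestOutput; its .outputs field
    # holds the pooled result (formerly an EmbeddingOutput, now a PoolingOutput).
    print(req_out.request_id, type(req_out.outputs).__name__)
```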

View File

@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
-from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
+from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
def _get_embedding(
-output: EmbeddingOutput,
+output: PoolingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
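The _get_embedding helper above now receives a PoolingOutput and serializes it either as a plain float list or as base64, matching the OpenAI embeddings API's encoding_format field. The hunk truncates the base64 branch; the sketch below is one plausible implementation under that assumption, not the code elided from the diff.

```python
import base64
import struct
from typing import List, Literal, Union

def encode_embedding(values: List[float],
                     encoding_format: Literal["float", "base64"]) -> Union[List[float], str]:
    """Illustrative: return the raw float list, or base64 of little-endian float32 bytes."""
    if encoding_format == "float":
        return values
    packed = struct.pack(f"<{len(values)}f", *values)  # pack as float32
    return base64.b64encode(packed).decode("utf-8")
```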
@@ -40,7 +40,7 @@ def _get_embedding(
def request_output_to_embedding_response(
-final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
-generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
num_prompts = len(engine_prompts)
# Non-streaming response
-final_res_batch: List[Optional[EmbeddingRequestOutput]]
+final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
assert all(final_res is not None for final_res in final_res_batch)
-final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response(

View File

@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
-from vllm.outputs import EmbeddingRequestOutput
+from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators, random_uuid
@@ -21,7 +21,7 @@ logger = init_logger(__name__)
def request_output_to_score_response(
-final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
score = None
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
-generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
input_pairs = make_pairs(request.text_1, request.text_2)
@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
num_prompts = len(engine_prompts)
# Non-streaming response
-final_res_batch: List[Optional[EmbeddingRequestOutput]]
+final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
assert all(final_res is not None for final_res in final_res_batch)
-final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = request_output_to_score_response(
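The scoring path mirrors the embedding path: request_output_to_score_response now consumes a batch of PoolingRequestOutput objects rather than EmbeddingRequestOutput. From the client side nothing changes; the hedged sketch below shows how LLM.score() might be called, with the model name, constructor arguments, and printed attributes being assumptions rather than details from this commit.

```python
from vllm import LLM

# Illustrative only: model name and constructor arguments are assumptions.
llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
results = llm.score("What is vLLM?",
                    ["vLLM is a high-throughput inference engine.",
                     "Bananas are yellow."])

for res in results:
    # After this commit each result is a PoolingRequestOutput whose .outputs
    # carries the similarity score for the corresponding <text, text_pair> pair.
    print(res.request_id, res.outputs)
```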