[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -26,7 +26,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||
GuidedDecodingRequest, LLMGuidedOptions)
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
@@ -679,7 +679,7 @@ class LLM:
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@overload # LEGACY: multi (prompt + optional token ids)
|
||||
@@ -691,7 +691,7 @@ class LLM:
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@overload # LEGACY: single (token ids + optional prompt)
|
||||
@@ -704,7 +704,7 @@ class LLM:
|
||||
prompt_token_ids: List[int],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@overload # LEGACY: multi (token ids + optional prompt)
|
||||
@@ -717,7 +717,7 @@ class LLM:
|
||||
prompt_token_ids: List[List[int]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@overload # LEGACY: single or multi token ids [pos-only]
|
||||
@@ -728,7 +728,7 @@ class LLM:
|
||||
prompt_token_ids: Union[List[int], List[List[int]]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@@ -741,7 +741,7 @@ class LLM:
|
||||
Sequence[PoolingParams]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@deprecate_kwargs(
|
||||
@@ -759,7 +759,7 @@ class LLM:
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
This class automatically batches the given prompts, considering
|
||||
@@ -778,7 +778,7 @@ class LLM:
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of ``EmbeddingRequestOutput`` objects containing the
|
||||
A list of ``PoolingRequestOutput`` objects containing the
|
||||
generated embeddings in the same order as the input prompts.
|
||||
|
||||
Note:
|
||||
@@ -821,7 +821,7 @@ class LLM:
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
EmbeddingRequestOutput)
|
||||
PoolingRequestOutput)
|
||||
|
||||
def score(
|
||||
self,
|
||||
@@ -832,7 +832,7 @@ class LLM:
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
) -> List[PoolingRequestOutput]:
|
||||
"""Generates similarity scores for all pairs <text,text_pair>.
|
||||
|
||||
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
|
||||
@@ -854,7 +854,7 @@ class LLM:
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of ``EmbeddingRequestOutput`` objects containing the
|
||||
A list of ``PoolingRequestOutput`` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
task = self.llm_engine.model_config.task
|
||||
@@ -943,7 +943,7 @@ class LLM:
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
EmbeddingRequestOutput)
|
||||
PoolingRequestOutput)
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.llm_engine.start_profile()
|
||||
@@ -1085,7 +1085,7 @@ class LLM:
|
||||
|
||||
def _run_engine(
|
||||
self, *, use_tqdm: bool
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||
@@ -1098,7 +1098,7 @@ class LLM:
|
||||
)
|
||||
|
||||
# Run the engine.
|
||||
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
|
||||
outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
|
||||
total_in_toks = 0
|
||||
total_out_toks = 0
|
||||
while self.llm_engine.has_unfinished_requests():
|
||||
|
||||
@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
output: PoolingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[List[float], str]:
|
||||
if encoding_format == "float":
|
||||
@@ -40,7 +40,7 @@ def _get_embedding(
|
||||
|
||||
|
||||
def request_output_to_embedding_response(
|
||||
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
||||
final_res_batch: List[PoolingRequestOutput], request_id: str,
|
||||
created_time: int, model_name: str,
|
||||
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
|
||||
data: List[EmbeddingResponseData] = []
|
||||
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
||||
final_res_batch: List[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
|
||||
final_res_batch_checked = cast(List[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = request_output_to_embedding_response(
|
||||
|
||||
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingRequestOutput
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators, random_uuid
|
||||
|
||||
@@ -21,7 +21,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
def request_output_to_score_response(
|
||||
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
||||
final_res_batch: List[PoolingRequestOutput], request_id: str,
|
||||
created_time: int, model_name: str) -> ScoreResponse:
|
||||
data: List[ScoreResponseData] = []
|
||||
score = None
|
||||
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
input_pairs = make_pairs(request.text_1, request.text_2)
|
||||
|
||||
@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
||||
final_res_batch: List[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
try:
|
||||
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
|
||||
final_res_batch_checked = cast(List[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = request_output_to_score_response(
|
||||
|
||||
Reference in New Issue
Block a user