[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-01 14:36:51 +08:00
committed by GitHub
parent f877a7d12a
commit d2f058e76c
25 changed files with 166 additions and 123 deletions

View File

@@ -26,7 +26,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -679,7 +679,7 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@@ -691,7 +691,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
@@ -704,7 +704,7 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@@ -717,7 +717,7 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
@@ -728,7 +728,7 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload
@@ -741,7 +741,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@deprecate_kwargs(
@@ -759,7 +759,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
@@ -778,7 +778,7 @@ class LLM:
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
A list of ``PoolingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
Note:
@@ -821,7 +821,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
PoolingRequestOutput)
def score(
self,
@@ -832,7 +832,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
@@ -854,7 +854,7 @@ class LLM:
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
A list of ``PoolingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
@@ -943,7 +943,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
PoolingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()
@@ -1085,7 +1085,7 @@ class LLM:
def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1098,7 +1098,7 @@ class LLM:
)
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():