[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-01 14:36:51 +08:00
committed by GitHub
parent f877a7d12a
commit d2f058e76c
25 changed files with 166 additions and 123 deletions

View File

@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@@ -83,7 +83,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
@@ -103,7 +103,7 @@ class AsyncStream:
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
try:
while True:
result = await self._queue.get()
@@ -154,7 +154,7 @@ class RequestTracker:
def process_request_output(self,
request_output: Union[RequestOutput,
EmbeddingRequestOutput],
PoolingRequestOutput],
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
RequestOutput, PoolingRequestOutput], None]]:
...
@overload
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
RequestOutput, PoolingRequestOutput], None]]:
...
@deprecate_kwargs(
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
Details:
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
yield LLMEngine.validate_output(output, PoolingRequestOutput)
async def abort(self, request_id: str) -> None:
"""Abort a request.