[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
|
||||
|
||||
|
||||
class AsyncStream:
|
||||
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
|
||||
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
|
||||
that can be iterated over asynchronously via an async generator."""
|
||||
|
||||
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
|
||||
@@ -83,7 +83,7 @@ class AsyncStream:
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._finished = False
|
||||
|
||||
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
|
||||
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
|
||||
Exception]) -> None:
|
||||
if not self._finished:
|
||||
self._queue.put_nowait(item)
|
||||
@@ -103,7 +103,7 @@ class AsyncStream:
|
||||
|
||||
async def generator(
|
||||
self
|
||||
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
|
||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||
try:
|
||||
while True:
|
||||
result = await self._queue.get()
|
||||
@@ -154,7 +154,7 @@ class RequestTracker:
|
||||
|
||||
def process_request_output(self,
|
||||
request_output: Union[RequestOutput,
|
||||
EmbeddingRequestOutput],
|
||||
PoolingRequestOutput],
|
||||
*,
|
||||
verbose: bool = False) -> None:
|
||||
"""Process a request output from the engine."""
|
||||
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
async def step_async(
|
||||
self, virtual_engine: int
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
The workers are ran asynchronously if possible.
|
||||
|
||||
@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Coroutine[None, None, AsyncGenerator[Union[
|
||||
RequestOutput, EmbeddingRequestOutput], None]]:
|
||||
RequestOutput, PoolingRequestOutput], None]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Coroutine[None, None, AsyncGenerator[Union[
|
||||
RequestOutput, EmbeddingRequestOutput], None]]:
|
||||
RequestOutput, PoolingRequestOutput], None]]:
|
||||
...
|
||||
|
||||
@deprecate_kwargs(
|
||||
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
|
||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
assert prompt is not None and params is not None
|
||||
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||
"""Generate outputs for a request from an embedding model.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
Only applicable with priority scheduling.
|
||||
|
||||
Yields:
|
||||
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
||||
The output `PoolingRequestOutput` objects from the LLMEngine
|
||||
for the request.
|
||||
|
||||
Details:
|
||||
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
):
|
||||
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
|
||||
yield LLMEngine.validate_output(output, PoolingRequestOutput)
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort a request.
|
||||
|
||||
Reference in New Issue
Block a user