[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
vllm/engine/async_llm_engine.py
@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
 from vllm.usage.usage_lib import UsageContext
@@ -47,15 +48,16 @@ def _raise_exception_on_finish(
 
 
 class AsyncStream:
-    """A stream of RequestOutputs for a request that can be
-    iterated over asynchronously."""
+    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    that can be iterated over asynchronously."""
 
     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
-    def put(self, item: Union[RequestOutput, Exception]) -> None:
+    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+                              Exception]) -> None:
         if self._finished:
             return
         self._queue.put_nowait(item)
@@ -71,7 +73,7 @@ class AsyncStream:
     def __aiter__(self):
         return self
 
-    async def __anext__(self) -> RequestOutput:
+    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
         result = await self._queue.get()
         if isinstance(result, Exception):
             raise result
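With the two hunks above, a single AsyncStream can now carry either output type, and an exception enqueued by the engine is re-raised in the consumer. A minimal sketch of that contract, assuming AsyncStream can be imported from this (internal) module:

import asyncio

from vllm.engine.async_llm_engine import AsyncStream


async def main() -> None:
    stream = AsyncStream("request-0")
    # The engine side may enqueue RequestOutput, EmbeddingRequestOutput,
    # or an Exception; exceptions are re-raised by __anext__.
    stream.put(RuntimeError("engine error"))
    try:
        async for item in stream:
            print(item)  # RequestOutput or EmbeddingRequestOutput
    except RuntimeError as e:
        print("stream surfaced engine failure:", e)


asyncio.run(main())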
@@ -108,7 +110,8 @@ class RequestTracker:
             self.abort_request(rid)
 
     def process_request_output(self,
-                               request_output: RequestOutput,
+                               request_output: Union[RequestOutput,
+                                                     EmbeddingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -196,7 +199,8 @@ class RequestTracker:
 class _AsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""
 
-    async def step_async(self) -> List[RequestOutput]:
+    async def step_async(
+            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are run asynchronously if possible.
 
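Because step_async now returns a mixed list, a driver loop distinguishes results by type. A hypothetical fragment (the engine argument and the print handlers are illustrative only, not part of this commit):

from vllm.outputs import EmbeddingRequestOutput, RequestOutput


async def drain_one_step(engine) -> None:
    # engine: a _AsyncLLMEngine; one call performs one engine iteration.
    outputs = await engine.step_async()
    for output in outputs:
        if isinstance(output, EmbeddingRequestOutput):
            print(output.request_id, "embedding ready")   # pooling request
        elif isinstance(output, RequestOutput):
            print(output.request_id, "tokens generated")  # sampling request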
@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
         self,
         request_id: str,
         prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
 
         return self.add_request(request_id,
                                 prompt=prompt,
+                                params=params,
                                 prompt_token_ids=prompt_token_ids,
-                                sampling_params=sampling_params,
                                 arrival_time=arrival_time,
                                 lora_request=lora_request,
                                 multi_modal_data=multi_modal_data)
@@ -511,7 +515,7 @@ class AsyncLLMEngine:
         self,
         request_id: str,
         prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -528,9 +532,9 @@ class AsyncLLMEngine:
                                                           max_log_len]
             logger.info(
                 "Received request %s: prompt: %r, "
-                "sampling_params: %s, prompt_token_ids: %s, "
-                "lora_request: %s.", request_id, shortened_prompt,
-                sampling_params, shortened_token_ids, lora_request)
+                "params: %s, prompt_token_ids: %s, "
+                "lora_request: %s.", request_id, shortened_prompt, params,
+                shortened_token_ids, lora_request)
 
         if not self.is_running:
             if self.start_engine_loop:
@@ -562,7 +566,7 @@ class AsyncLLMEngine:
         stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
-            sampling_params=sampling_params,
+            params=params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
@@ -597,8 +601,8 @@ class AsyncLLMEngine:
             multi_modal_data: Multi modal data per request.
 
         Yields:
-            The output `RequestOutput` objects from the LLMEngine for the
-            request.
+            The output `RequestOutput` objects from the LLMEngine
+            for the request.
 
         Details:
             - If the engine is not running, start the background loop,
@@ -643,25 +647,123 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        # Preprocess the request.
-        arrival_time = time.time()
-
-        try:
-            stream = await self.add_request(
+        async for output in self.process_request(
                 request_id,
                 prompt,
                 sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time,
-                lora_request=lora_request,
-                multi_modal_data=multi_modal_data,
-            )
+                prompt_token_ids,
+                lora_request,
+                multi_modal_data,
+        ):
+            yield output
+
+    async def encode(
+            self,
+            prompt: Optional[str],
+            pooling_params: PoolingParams,
+            request_id: str,
+            prompt_token_ids: Optional[List[int]] = None,
+            lora_request: Optional[LoRARequest] = None,
+            multi_modal_data: Optional[MultiModalData] = None
+    ) -> AsyncIterator[EmbeddingRequestOutput]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data per request.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+
+        Example:
+            >>> # Please refer to entrypoints/api_server.py for
+            >>> # the complete example.
+            >>>
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "input": "What is LLM?",
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.encode(
+            >>>     example_input["input"],
+            >>>     PoolingParams(),
+            >>>     example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
+        """
+        async for output in self.process_request(
+                request_id,
+                prompt,
+                pooling_params,
+                prompt_token_ids,
+                lora_request,
+                multi_modal_data,
+        ):
+            yield output
+
+    async def process_request(
+            self,
+            request_id: str,
+            prompt: Optional[str],
+            params: Union[SamplingParams, PoolingParams],
+            prompt_token_ids: Optional[List[int]] = None,
+            lora_request: Optional[LoRARequest] = None,
+            multi_modal_data: Optional[MultiModalData] = None,
+    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
+        """Common logic to process requests with SamplingParams or
+        PoolingParams."""
+        arrival_time = time.time()
+
+        stream = await self.add_request(
+            request_id,
+            prompt,
+            params,
+            prompt_token_ids=prompt_token_ids,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            multi_modal_data=multi_modal_data,
+        )
 
+        try:
             async for request_output in stream:
                 yield request_output
         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
             # request.
             self._abort(request_id)
             raise e
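Taken together, the commit lets the async engine serve embedding models end to end. A hedged usage sketch based on the encode docstring above; AsyncEngineArgs and the outputs.embedding field on EmbeddingRequestOutput are assumptions about the surrounding vLLM API rather than part of this diff:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.pooling_params import PoolingParams


async def main() -> None:
    # Model name taken from the PR title; other engine args left at defaults.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct"))

    # encode() streams EmbeddingRequestOutputs; the last one is final.
    final_output = None
    async for request_output in engine.encode("What is LLM?", PoolingParams(),
                                              "request-0"):
        final_output = request_output

    # Assumed result shape: the pooled vector lives at outputs.embedding.
    print(len(final_output.outputs.embedding))


asyncio.run(main())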