[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)

Author: Chang Su
Date: 2024-05-11 11:30:37 -07:00
Committed by: GitHub
Parent: 4e12131089
Commit: e254497b66
38 changed files with 1627 additions and 160 deletions

vllm/engine/async_llm_engine.py

@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext
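
For orientation, the pairing introduced by these imports is: generation requests carry SamplingParams and produce RequestOutput, while embedding requests carry the new PoolingParams and produce EmbeddingRequestOutput. A minimal typing sketch of that pairing (the alias names below are illustrative, not part of this commit):

    from typing import Union

    from vllm.outputs import EmbeddingRequestOutput, RequestOutput
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams

    # Hypothetical aliases mirroring the unions used throughout this diff:
    # the params type selects the request kind, the output type matches it.
    RequestParams = Union[SamplingParams, PoolingParams]
    EngineOutput = Union[RequestOutput, EmbeddingRequestOutput]
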
@@ -47,15 +48,16 @@ def _raise_exception_on_finish(
class AsyncStream:
"""A stream of RequestOutputs for a request that can be
iterated over asynchronously."""
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
-    def put(self, item: Union[RequestOutput, Exception]) -> None:
+    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+                              Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
@@ -71,7 +73,7 @@ class AsyncStream:
def __aiter__(self):
return self
-    async def __anext__(self) -> RequestOutput:
+    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
result = await self._queue.get()
if isinstance(result, Exception):
raise result
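
Because AsyncStream can now carry either output type, consumers dispatch on the concrete class. A minimal consumer sketch (the handler bodies are placeholders; only the iteration and exception behavior come from the diff above):

    from vllm.outputs import EmbeddingRequestOutput, RequestOutput

    async def consume(stream: "AsyncStream") -> None:
        # Each item is a RequestOutput (generation) or an
        # EmbeddingRequestOutput (embedding); exceptions put on the
        # stream are re-raised by __anext__.
        async for output in stream:
            if isinstance(output, EmbeddingRequestOutput):
                ...  # handle the pooled embedding result
            else:
                ...  # handle the generated text (RequestOutput)
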
@@ -108,7 +110,8 @@ class RequestTracker:
self.abort_request(rid)
def process_request_output(self,
-                               request_output: RequestOutput,
+                               request_output: Union[RequestOutput,
+                                                     EmbeddingRequestOutput],
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
@@ -196,7 +199,8 @@ class RequestTracker:
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
-    async def step_async(self) -> List[RequestOutput]:
+    async def step_async(
+            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.
@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
self,
request_id: str,
prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
return self.add_request(request_id,
prompt=prompt,
+                                params=params,
prompt_token_ids=prompt_token_ids,
-                                sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
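
With the engine's add_request now taking a single params argument, the same entry point serves both request kinds. A sketch of the two call shapes (the engine instance is an already-constructed LLMEngine, not shown here; request ids are illustrative, and which params type is appropriate depends on the model the engine was loaded with):

    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams

    # Generation request: params is a SamplingParams instance.
    engine.add_request("gen-0",
                       prompt="What is LLM?",
                       params=SamplingParams(max_tokens=32))

    # Embedding request: params is a PoolingParams instance.
    engine.add_request("emb-0",
                       prompt="What is LLM?",
                       params=PoolingParams())
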
@@ -511,7 +515,7 @@ class AsyncLLMEngine:
self,
request_id: str,
prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -528,9 +532,9 @@ class AsyncLLMEngine:
max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"sampling_params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt,
sampling_params, shortened_token_ids, lora_request)
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
@@ -562,7 +566,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
-            sampling_params=sampling_params,
+            params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
@@ -597,8 +601,8 @@ class AsyncLLMEngine:
multi_modal_data: Multi modal data per request.
Yields:
-            The output `RequestOutput` objects from the LLMEngine for the
-            request.
+            The output `RequestOutput` objects from the LLMEngine
+            for the request.
Details:
- If the engine is not running, start the background loop,
@@ -643,25 +647,123 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
-        # Preprocess the request.
-        arrival_time = time.time()
-        try:
-            stream = await self.add_request(
+        async for output in self.process_request(
request_id,
prompt,
sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time,
-                lora_request=lora_request,
-                multi_modal_data=multi_modal_data,
-            )
+                prompt_token_ids,
+                lora_request,
+                multi_modal_data,
+        ):
+            yield output
+
+    async def encode(
+        self,
+        prompt: Optional[str],
+        pooling_params: PoolingParams,
+        request_id: str,
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None
+    ) -> AsyncIterator[EmbeddingRequestOutput]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data per request.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+
+        Example:
+            >>> # Please refer to entrypoints/api_server.py for
+            >>> # the complete example.
+            >>>
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "input": "What is LLM?",
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.encode(
+            >>>     example_input["input"],
+            >>>     PoolingParams(),
+            >>>     example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
+        """
+        async for output in self.process_request(
+                request_id,
+                prompt,
+                pooling_params,
+                prompt_token_ids,
+                lora_request,
+                multi_modal_data,
+        ):
+            yield output
+
+    async def process_request(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        params: Union[SamplingParams, PoolingParams],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
+        """Common logic to process requests with SamplingParams or
+        PoolingParams."""
+        arrival_time = time.time()
+
+        stream = await self.add_request(
+            request_id,
+            prompt,
+            params,
+            prompt_token_ids=prompt_token_ids,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            multi_modal_data=multi_modal_data,
+        )
+
+        try:
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id)
raise e
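
Putting the pieces together: generate() and encode() both delegate to process_request(), which registers the request via add_request() and streams results from the per-request AsyncStream, aborting the request on error or cancellation. A minimal end-to-end sketch of driving the new embedding path (the model name and engine arguments are illustrative; running it requires a working vLLM install and suitable hardware):

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.pooling_params import PoolingParams

    async def main() -> None:
        engine_args = AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct")
        engine = AsyncLLMEngine.from_engine_args(engine_args)

        final_output = None
        # encode() streams EmbeddingRequestOutput objects; the last one
        # yielded for the request id is the finished result.
        async for output in engine.encode("What is LLM?", PoolingParams(),
                                          "embedding-request-0"):
            final_output = output

        # final_output is an EmbeddingRequestOutput carrying the pooled
        # embedding for the prompt.
        print(final_output)

    asyncio.run(main())
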