[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
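This change threads an optional LoRARequest through the async engine's request
path: _AsyncLLMEngine gains encode_request_async and add_request_async, and
AsyncLLMEngine.add_request/generate accept and forward a lora_request. A rough
usage sketch, not taken from this diff (the enable_lora engine arg, the
LoRARequest(name, int_id, local_path) constructor, and the model and adapter
paths are illustrative assumptions):

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest


async def main():
    # enable_lora is assumed to be the engine arg that populates lora_config.
    engine_args = AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf",
                                  enable_lora=True)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Hypothetical adapter: name, integer id, local path.
    lora_request = LoRARequest("sql-adapter", 1, "/path/to/sql_adapter")
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # generate() now accepts lora_request and forwards it to add_request_async.
    final_output = None
    async for request_output in engine.generate("SELECT count(*) FROM users",
                                                sampling_params,
                                                request_id="req-0",
                                                lora_request=lora_request):
        final_output = request_output
    print(final_output.outputs[0].text)


asyncio.run(main())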
@@ -4,6 +4,7 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator)
 
+from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
@@ -203,6 +204,52 @@ class _AsyncLLMEngine(LLMEngine):
 
         return self._process_model_outputs(output, scheduler_outputs)
 
+    async def encode_request_async(
+        self,
+        request_id: str,  # pylint: disable=unused-argument
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ):
+        if prompt_token_ids is None:
+            assert prompt is not None
+            prompt_token_ids = await self.tokenizer.encode_async(
+                request_id=request_id,
+                prompt=prompt,
+                lora_request=lora_request)
+        return prompt_token_ids
+
+    async def add_request_async(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
+    ) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+        if arrival_time is None:
+            arrival_time = time.time()
+        prompt_token_ids = await self.encode_request_async(
+            request_id=request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            lora_request=lora_request)
+
+        return self.add_request(
+            request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prefix_pos=prefix_pos,
+        )
+
     async def _run_workers_async(
         self,
         method: str,
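The two methods above move tokenization onto the async path and reject a
lora_request when LoRA is not enabled. A minimal sketch of driving
add_request_async directly, assuming engine is an already-initialized
_AsyncLLMEngine with LoRA enabled and that the adapter below has been made
available (names and path are hypothetical):

from vllm import SamplingParams
from vllm.lora.request import LoRARequest

async def submit_with_adapter(engine):
    # Raises ValueError if lora_config is not set on the engine.
    await engine.add_request_async(
        request_id="req-0",
        prompt="SELECT count(*) FROM users",
        sampling_params=SamplingParams(max_tokens=32),
        lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql_adapter"))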
@@ -332,7 +379,7 @@ class AsyncLLMEngine:
             if self.engine_use_ray:
                 await self.engine.add_request.remote(**new_request)
             else:
-                self.engine.add_request(**new_request)
+                await self.engine.add_request_async(**new_request)
 
         if finished_requests:
             await self._engine_abort(finished_requests)
@@ -371,6 +418,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
         prefix_pos: Optional[int] = None,
     ) -> AsyncStream:
         if self.log_requests:
@@ -386,7 +434,8 @@ class AsyncLLMEngine:
                         f"prompt: {shortened_prompt!r}, "
                         f"prefix_pos: {prefix_pos},"
                         f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {shortened_token_ids}.")
+                        f"prompt token ids: {shortened_token_ids}, "
+                        f"lora_request: {lora_request}.")
 
         if not self.is_running:
             if self.start_engine_loop:
@@ -398,12 +447,21 @@ class AsyncLLMEngine:
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")
 
+        if arrival_time is None:
+            arrival_time = time.time()
+        prompt_token_ids = await self.engine.encode_request_async(
+            request_id=request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            lora_request=lora_request)
+
         stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
+            lora_request=lora_request,
             prefix_pos=prefix_pos)
 
         return stream
@@ -414,6 +472,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         request_id: str,
         prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
         prefix_pos: Optional[int] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
@@ -429,6 +488,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
             prefix_pos: If not None, we use the given position as the prefix
                 position for each prompt. We will cache the prefix's KV
                 cache and reuse it for the next request with the same prefix.
@@ -487,12 +547,15 @@ class AsyncLLMEngine:
         arrival_time = time.monotonic()
 
         try:
-            stream = await self.add_request(request_id,
-                                            prompt,
-                                            sampling_params,
-                                            prompt_token_ids=prompt_token_ids,
-                                            arrival_time=arrival_time,
-                                            prefix_pos=prefix_pos)
+            stream = await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                prefix_pos=prefix_pos,
+            )
 
             async for request_output in stream:
                 yield request_output