[Experimental] Add multi-LoRA support (#1804)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
Antoni Baum
2024-01-24 00:26:37 +01:00
committed by GitHub
parent 9c1352eb57
commit 9b945daaf1
52 changed files with 8035 additions and 126 deletions

View File

@@ -4,6 +4,7 @@ from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
@@ -203,6 +204,52 @@ class _AsyncLLMEngine(LLMEngine):
return self._process_model_outputs(output, scheduler_outputs)
async def encode_request_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async def _run_workers_async(
self,
method: str,
@@ -332,7 +379,7 @@ class AsyncLLMEngine:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
self.engine.add_request(**new_request)
await self.engine.add_request_async(**new_request)
if finished_requests:
await self._engine_abort(finished_requests)
@@ -371,6 +418,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
@@ -386,7 +434,8 @@ class AsyncLLMEngine:
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}.")
f"prompt token ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
if not self.is_running:
if self.start_engine_loop:
@@ -398,12 +447,21 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
return stream
@@ -414,6 +472,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -429,6 +488,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
@@ -487,12 +547,15 @@ class AsyncLLMEngine:
arrival_time = time.monotonic()
try:
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
prefix_pos=prefix_pos)
stream = await self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async for request_output in stream:
yield request_output