[Experimental] Add multi-LoRA support (#1804)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
Author: Antoni Baum
Date: 2024-01-24 00:26:37 +01:00
Committed by: GitHub
Parent: 9c1352eb57
Commit: 9b945daaf1
52 changed files with 8035 additions and 126 deletions
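For orientation, a minimal sketch of how the multi-LoRA path added here is meant to be driven from the engine API. The enable_lora flag and the LoRARequest constructor arguments other than lora_int_id are assumptions based on other files in this commit, not on the hunks shown below; the model name and adapter path are placeholders.

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora and the extra LoRARequest fields are assumed from the rest of
# the commit; only lora_int_id and the add_request plumbing appear below.
engine = LLMEngine.from_engine_args(
    EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))
sql_lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora")

engine.add_request(
    request_id="0",
    prompt="Write a SQL query that lists all users.",
    sampling_params=SamplingParams(max_tokens=64),
    lora_request=sql_lora,  # tokenized and scheduled per adapter
)
while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)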


@@ -5,8 +5,9 @@ import time
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-SchedulerConfig)
+SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import record_metrics
@@ -17,7 +18,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
-get_tokenizer)
+TokenizerGroup)
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
if ray:
@@ -64,6 +65,7 @@ class LLMEngine:
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
placement_group: Optional["PlacementGroup"],
log_stats: bool,
) -> None:
@@ -87,17 +89,13 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()
-self.tokenizer = get_tokenizer(
-model_config.tokenizer,
-tokenizer_mode=model_config.tokenizer_mode,
-trust_remote_code=model_config.trust_remote_code,
-tokenizer_revision=model_config.tokenizer_revision,
-revision=model_config.revision)
+self._init_tokenizer()
self.seq_counter = Counter()
# Create the parallel GPU workers.
@@ -114,7 +112,7 @@ class LLMEngine:
self._init_cache()
# Create the scheduler.
-self.scheduler = Scheduler(scheduler_config, cache_config)
+self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Logging.
self.last_logging_time = 0.0
@@ -123,6 +121,9 @@ class LLMEngine:
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
@@ -141,11 +142,24 @@ class LLMEngine:
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=True,
)
self._run_workers("init_model")
self._run_workers("load_model")
def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict(
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(
self.model_config.tokenizer, **init_kwargs)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
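The _init_tokenizer hunk above replaces the single shared tokenizer with a TokenizerGroup that the engine queries per request (encode_request) and per sequence (get_tokenizer_for_seq). The following is a conceptual sketch of that contract, not the class shipped in this commit; the lora_local_path field and the fallback behaviour are assumptions.

from typing import Dict, List, Optional

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest


class MiniTokenizerGroup:
    """Illustrative stand-in for TokenizerGroup: one base tokenizer plus an
    optional per-adapter tokenizer keyed by lora_int_id."""

    def __init__(self, tokenizer_id: str, enable_lora: bool, **hf_kwargs):
        self.enable_lora = enable_lora
        self.base: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            tokenizer_id, **hf_kwargs)
        self._per_lora: Dict[int, PreTrainedTokenizerBase] = {}

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest]) -> PreTrainedTokenizerBase:
        if not self.enable_lora or lora_request is None:
            return self.base
        lora_id = lora_request.lora_int_id
        if lora_id not in self._per_lora:
            try:
                # Assumed field name: the adapter directory may bundle its
                # own tokenizer (e.g. with added special tokens).
                tok = AutoTokenizer.from_pretrained(
                    lora_request.lora_local_path)
            except Exception:
                tok = self.base  # no adapter tokenizer; reuse the base one
            self._per_lora[lora_id] = tok
        return self._per_lora[lora_id]

    def encode(self,
               request_id: str,
               prompt: str,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        # Mirrors encode_request() below: pick the tokenizer that matches
        # the request's adapter before producing prompt token ids.
        return self.get_lora_tokenizer(lora_request).encode(prompt)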
@@ -233,6 +247,7 @@ class LLMEngine:
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
))
driver_rank = 0
@@ -244,6 +259,7 @@ class LLMEngine:
driver_local_rank,
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=True,
)
@@ -257,6 +273,10 @@ class LLMEngine:
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
@@ -332,6 +352,20 @@ class LLMEngine:
log_stats=not engine_args.disable_log_stats)
return engine
def encode_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def add_request(
self,
request_id: str,
@@ -339,6 +373,7 @@ class LLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.
@@ -386,24 +421,31 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.monotonic()
-if prompt_token_ids is None:
-assert prompt is not None
-prompt_token_ids = self.tokenizer.encode(prompt)
+prompt_token_ids = self.encode_request(
+request_id=request_id,
+prompt=prompt,
+prompt_token_ids=prompt_token_ids,
+lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
-seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
+seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+lora_request)
# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
-prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
+prompt_token_ids[:prefix_pos], lora_request.lora_int_id
+if lora_request else 0) if prefix_pos is not None else None
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
-arrival_time, prefix)
+arrival_time, lora_request, prefix)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
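Note that the prefix pool above is now keyed by lora_int_id as well as by the prompt prefix: a LoRA adapter changes the projections that produce the cached KV entries, so a prefix computed under one adapter cannot be reused by another (0 denotes the base model). A hypothetical keying helper, named only for illustration:

from typing import List, Tuple


def prefix_cache_key(prompt_token_ids: List[int],
                     lora_int_id: int = 0) -> Tuple[Tuple[int, ...], int]:
    # Matches the `lora_request.lora_int_id if lora_request else 0` call
    # above: identical prefixes under different adapters get distinct keys.
    return (tuple(prompt_token_ids), lora_int_id)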
@@ -453,11 +495,13 @@ class LLMEngine:
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
-eos_token_id=self.tokenizer.eos_token_id))
+eos_token_id=self.get_tokenizer_for_seq(
+current_worst_seq).eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
-eos_token_id=self.tokenizer.eos_token_id))
+eos_token_id=self.get_tokenizer_for_seq(
+best_running_seq).eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
@@ -471,7 +515,8 @@ class LLMEngine:
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
-eos_token_id=self.tokenizer.eos_token_id,
+eos_token_id=self.get_tokenizer_for_seq(
+best_running_seq).eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
@@ -480,7 +525,8 @@ class LLMEngine:
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
-eos_token_id=self.tokenizer.eos_token_id))
+eos_token_id=self.get_tokenizer_for_seq(
+best_running_seq).eos_token_id))
return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@@ -571,7 +617,7 @@ class LLMEngine:
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
-eos_token_id=self.tokenizer.eos_token_id),
+eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
@@ -599,7 +645,7 @@ class LLMEngine:
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
-eos_token_id=self.tokenizer.eos_token_id),
+eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True)
# Check if we can stop the beam search.
@@ -837,7 +883,7 @@ class LLMEngine:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
-self.tokenizer,
+self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(),
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
@@ -879,11 +925,28 @@ class LLMEngine:
return
# Check if the sequence has generated the EOS token.
-if ((not sampling_params.ignore_eos)
-and seq.get_last_token_id() == self.tokenizer.eos_token_id):
+if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
+== self.get_tokenizer_for_seq(seq).eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> List[int]:
return self._run_workers("list_loras")
def _run_workers(
self,
method: str,
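The add_lora / remove_lora / list_loras methods above broadcast to every worker through _run_workers, so adapters can be registered and evicted at runtime without recreating the engine. A hedged usage sketch, reusing the engine object and the assumed LoRARequest fields from the example at the top; the adapter name and path are placeholders.

from vllm.lora.request import LoRARequest

chat_lora = LoRARequest("chat-adapter", 2, "/path/to/chat_lora")

engine.add_lora(chat_lora)   # every worker loads the adapter
print(engine.list_loras())   # adapter ids currently resident on the workers

# ... serve requests whose LoRARequest uses lora_int_id == 2 ...

engine.remove_lora(2)        # evict the adapter once that traffic drains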