[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
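For context, a usage sketch of the engine-level API this commit introduces. The sketch is illustrative only and is not part of this diff; the `enable_lora` engine flag and the LoRARequest(name, int_id, local_path) field order are assumptions beyond what the hunks below show.

    from vllm import EngineArgs, LLMEngine, SamplingParams
    from vllm.lora.request import LoRARequest

    # Assumed flag: turning LoRA on at engine construction is what populates
    # the LoRAConfig threaded through this commit.
    engine = LLMEngine.from_engine_args(
        EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))

    # Each adapter is identified by a positive integer id; 0 means "no LoRA".
    lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora")  # placeholder path

    engine.add_request(request_id="0",
                       prompt="Write a SQL query that ...",
                       sampling_params=SamplingParams(temperature=0.0),
                       lora_request=lora)

    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)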
@@ -5,8 +5,9 @@ import time
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
+from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, LoRAConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import record_metrics
@@ -17,7 +18,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
-                                               get_tokenizer)
+                                               TokenizerGroup)
 from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
 
 if ray:
@@ -64,6 +65,7 @@ class LLMEngine:
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        lora_config: Optional[LoRAConfig],
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
@@ -87,17 +89,13 @@ class LLMEngine:
 
         self.model_config = model_config
         self.cache_config = cache_config
+        self.lora_config = lora_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
         self._verify_args()
 
-        self.tokenizer = get_tokenizer(
-            model_config.tokenizer,
-            tokenizer_mode=model_config.tokenizer_mode,
-            trust_remote_code=model_config.trust_remote_code,
-            tokenizer_revision=model_config.tokenizer_revision,
-            revision=model_config.revision)
+        self._init_tokenizer()
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
@@ -114,7 +112,7 @@ class LLMEngine:
         self._init_cache()
 
         # Create the scheduler.
-        self.scheduler = Scheduler(scheduler_config, cache_config)
+        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
 
         # Logging.
         self.last_logging_time = 0.0
@@ -123,6 +121,9 @@ class LLMEngine:
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []
 
+    def get_tokenizer_for_seq(self, sequence: Sequence):
+        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
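get_lora_tokenizer lives on the new TokenizerGroup and is not part of this file; the behaviour the engine relies on here is roughly the following sketch (the helper name and caching policy are assumptions, not the actual implementation):

    # Hypothetical sketch of TokenizerGroup.get_lora_tokenizer; the real code
    # lives in vllm/transformers_utils/tokenizer.py and may differ.
    def get_lora_tokenizer(self, lora_request=None):
        if not self.enable_lora or lora_request is None:
            return self.tokenizer  # shared base tokenizer
        # Per-adapter tokenizers cached by the adapter's integer id.
        tokenizer = self.lora_tokenizers.get(lora_request.lora_int_id)
        if tokenizer is None:
            tokenizer = _load_adapter_tokenizer(lora_request)  # hypothetical helper
            self.lora_tokenizers[lora_request.lora_int_id] = tokenizer
        return tokenizer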
@@ -141,11 +142,24 @@ class LLMEngine:
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
             is_driver_worker=True,
         )
         self._run_workers("init_model")
         self._run_workers("load_model")
 
+    def _init_tokenizer(self, **tokenizer_init_kwargs):
+        init_kwargs = dict(
+            enable_lora=bool(self.lora_config),
+            max_num_seqs=self.scheduler_config.max_num_seqs,
+            max_input_length=None,
+            tokenizer_mode=self.model_config.tokenizer_mode,
+            trust_remote_code=self.model_config.trust_remote_code,
+            revision=self.model_config.tokenizer_revision)
+        init_kwargs.update(tokenizer_init_kwargs)
+        self.tokenizer: TokenizerGroup = TokenizerGroup(
+            self.model_config.tokenizer, **init_kwargs)
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if self.parallel_config.tensor_parallel_size == 1:
@@ -233,6 +247,7 @@ class LLMEngine:
                     local_rank,
                     rank,
                     distributed_init_method,
+                    lora_config=self.lora_config,
                 ))
 
         driver_rank = 0
@@ -244,6 +259,7 @@ class LLMEngine:
             driver_local_rank,
             driver_rank,
             distributed_init_method,
+            lora_config=self.lora_config,
             is_driver_worker=True,
         )
 
@@ -257,6 +273,10 @@ class LLMEngine:
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
 
     def _init_cache(self) -> None:
         """Profiles the memory usage and initializes the KV cache.
@@ -332,6 +352,20 @@ class LLMEngine:
             log_stats=not engine_args.disable_log_stats)
         return engine
 
+    def encode_request(
+        self,
+        request_id: str,  # pylint: disable=unused-argument
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ):
+        if prompt_token_ids is None:
+            assert prompt is not None
+            prompt_token_ids = self.tokenizer.encode(request_id=request_id,
+                                                     prompt=prompt,
+                                                     lora_request=lora_request)
+        return prompt_token_ids
+
     def add_request(
         self,
         request_id: str,
@@ -339,6 +373,7 @@ class LLMEngine:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
         prefix_pos: Optional[int] = None,
     ) -> None:
         """Add a request to the engine's request pool.
@@ -386,24 +421,31 @@ class LLMEngine:
         >>> # continue the request processing
         >>> ...
         """
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
         if arrival_time is None:
             arrival_time = time.monotonic()
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = self.tokenizer.encode(prompt)
+        prompt_token_ids = self.encode_request(
+            request_id=request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            lora_request=lora_request)
 
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                       lora_request)
 
         # Check whether the input specifies prefix
         prefix = self.scheduler.prefix_pool.add_or_get_prefix(
-            prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
+            prompt_token_ids[:prefix_pos], lora_request.lora_int_id
+            if lora_request else 0) if prefix_pos is not None else None
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
-                                  arrival_time, prefix)
+                                  arrival_time, lora_request, prefix)
 
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
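One consequence of passing lora_request.lora_int_id into add_or_get_prefix above is that cached prompt prefixes are no longer shared across adapters: the pool behaves like a mapping keyed by the prefix tokens plus the adapter id. A hypothetical illustration of that keying (the real pool returns Prefix objects, not sentinels):

    pool = {}

    def add_or_get_prefix(token_ids, lora_int_id):
        key = (tuple(token_ids), lora_int_id)
        return pool.setdefault(key, object())

    p_base = add_or_get_prefix([1, 2, 3], 0)  # request without a LoRA
    p_lora = add_or_get_prefix([1, 2, 3], 1)  # same prefix tokens, adapter id 1
    assert p_base is not p_lora  # separate cache entries, no KV sharing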
@@ -453,11 +495,13 @@ class LLMEngine:
 
         current_worst_score = (current_worst_seq.get_beam_search_score(
             length_penalty=length_penalty,
-            eos_token_id=self.tokenizer.eos_token_id))
+            eos_token_id=self.get_tokenizer_for_seq(
+                current_worst_seq).eos_token_id))
         if early_stopping is False:
             highest_attainable_score = (best_running_seq.get_beam_search_score(
                 length_penalty=length_penalty,
-                eos_token_id=self.tokenizer.eos_token_id))
+                eos_token_id=self.get_tokenizer_for_seq(
+                    best_running_seq).eos_token_id))
         else:
             assert early_stopping == "never"
             if length_penalty > 0.0:
@@ -471,7 +515,8 @@ class LLMEngine:
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.tokenizer.eos_token_id,
+                        eos_token_id=self.get_tokenizer_for_seq(
+                            best_running_seq).eos_token_id,
                         seq_len=max_possible_length))
             else:
                 # Otherwise, beam search will prefer shorter sequences. The
@@ -480,7 +525,8 @@ class LLMEngine:
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.tokenizer.eos_token_id))
+                        eos_token_id=self.get_tokenizer_for_seq(
+                            best_running_seq).eos_token_id))
         return current_worst_score >= highest_attainable_score
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@@ -571,7 +617,7 @@ class LLMEngine:
         # Sort the finished sequences by their scores.
         all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
             length_penalty=length_penalty,
-            eos_token_id=self.tokenizer.eos_token_id),
+            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
                                reverse=True)
         for seq, parent, is_new in all_finished_seqs[:beam_width]:
             if is_new:
@@ -599,7 +645,7 @@ class LLMEngine:
         # Sort the running sequences by their scores.
         running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
             length_penalty=length_penalty,
-            eos_token_id=self.tokenizer.eos_token_id),
+            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
                                 reverse=True)
 
         # Check if we can stop the beam search.
@@ -837,7 +883,7 @@ class LLMEngine:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
-             self.tokenizer,
+             self.get_tokenizer_for_seq(seq),
              all_input_ids=seq.get_token_ids(),
              prev_tokens=seq.tokens,
              prefix_offset=seq.prefix_offset,
@@ -879,11 +925,28 @@ class LLMEngine:
             return
 
         # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
+        if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
+                == self.get_tokenizer_for_seq(seq).eos_token_id):
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def list_loras(self) -> List[int]:
+        return self._run_workers("list_loras")
+
     def _run_workers(
         self,
         method: str,
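The new add_lora / remove_lora / list_loras hooks let callers manage adapters on all workers outside the request path, for example to pre-register an adapter before traffic arrives. A rough sketch (adapter name and path are placeholders):

    from vllm.lora.request import LoRARequest

    engine.add_lora(LoRARequest("sql-adapter", 1, "/adapters/sql_lora"))
    print(engine.list_loras())   # ids of the adapters currently registered
    engine.remove_lora(1)        # evict adapter 1 from all workers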