[V1] LoRA Support (#10957)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Committed by: GitHub
Parent: 8108ac841d
Commit: 467a96a541
vllm/v1/worker/gpu_input_batch.py
@@ -3,11 +3,12 @@
 # Datastructures defining an input batch
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import torch
 
+from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -35,6 +36,8 @@ class CachedRequestState:
     mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[int] = None
 
+    lora_request: Optional[LoRARequest] = None
+
     @property
     def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)
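For reference, the lora_request field added here holds the same lightweight LoRARequest descriptor that this diff constructs elsewhere (lora_name, lora_int_id, lora_path). A minimal sketch with made-up adapter values, assuming only that vLLM is importable:

# Sketch only: the keyword names follow the LoRARequest usage later in this
# diff; the adapter name and path below are hypothetical.
from vllm.lora.request import LoRARequest

lora_req = LoRARequest(lora_name="sql-adapter",
                       lora_int_id=1,
                       lora_path="/adapters/sql-adapter")
# A request that uses an adapter caches it on its CachedRequestState via
# lora_request=lora_req; a request without one leaves the field as None.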
@@ -161,6 +164,12 @@ class InputBatch:
         ]
         self.prompt_token_ids: Optional[torch.Tensor] = None
 
+        # lora related
+        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
+                                             dtype=np.int32)
+        self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
+        self.lora_id_to_lora_request: Dict[int, LoRARequest] = {}
+
         # req_index -> generator
         # NOTE(woosuk): The indices of the requests that do not have their own
         # generator should not be included in the dictionary.
@@ -235,6 +244,19 @@ class InputBatch:
         if sampling_params.prompt_logprobs:
             self.prompt_logprob_reqs.add(req_id)
 
+        # Add request lora ID
+        if request.lora_request:
+            lora_id = request.lora_request.lora_int_id
+            if lora_id not in self.lora_id_to_request_ids:
+                self.lora_id_to_request_ids[lora_id] = set()
+
+            self.request_lora_mapping[req_index] = lora_id
+            self.lora_id_to_request_ids[lora_id].add(request.req_id)
+            self.lora_id_to_lora_request[lora_id] = request.lora_request
+        else:
+            # No LoRA
+            self.request_lora_mapping[req_index] = 0
+
     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
@@ -251,6 +273,16 @@ class InputBatch:
             self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
+
+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            self.lora_id_to_request_ids[lora_id].discard(req_id)
+            if len(self.lora_id_to_request_ids[lora_id]) == 0:
+                self.lora_id_to_request_ids.pop(lora_id)
+                self.lora_id_to_lora_request.pop(lora_id)
+            self.request_lora_mapping[req_index] = 0
+
         return req_index
 
     def clear(self) -> None:
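Taken together, the add_request and remove_request hunks maintain three structures: request_lora_mapping (one LoRA id per batch slot, 0 meaning no LoRA), lora_id_to_request_ids (which live requests still reference each adapter), and lora_id_to_lora_request (the adapter metadata, dropped once its last request leaves). A self-contained sketch of that bookkeeping with hypothetical request ids, independent of the vLLM classes:

# Minimal sketch, assuming three batch slots and a single adapter with id 1.
from typing import Dict, Set

import numpy as np

request_lora_mapping = np.zeros((3, ), dtype=np.int32)
lora_id_to_request_ids: Dict[int, Set[str]] = {}

def add(req_index: int, req_id: str, lora_id: int) -> None:
    # lora_id == 0 is reserved for "no LoRA".
    request_lora_mapping[req_index] = lora_id
    if lora_id != 0:
        lora_id_to_request_ids.setdefault(lora_id, set()).add(req_id)

def remove(req_index: int, req_id: str) -> None:
    lora_id = int(request_lora_mapping[req_index])
    if lora_id != 0:
        lora_id_to_request_ids[lora_id].discard(req_id)
        if not lora_id_to_request_ids[lora_id]:
            lora_id_to_request_ids.pop(lora_id)  # adapter no longer referenced
        request_lora_mapping[req_index] = 0

add(0, "req-a", 1)
add(1, "req-b", 0)
add(2, "req-c", 1)
remove(0, "req-a")
assert list(request_lora_mapping) == [0, 0, 1]
assert lora_id_to_request_ids == {1: {"req-c"}}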
@@ -266,6 +298,9 @@ class InputBatch:
         self.generators.clear()
         self.num_logprobs.clear()
         self.prompt_logprob_reqs.clear()
+        self.request_lora_mapping.fill(0)
+        self.lora_id_to_lora_request.clear()
+        self.lora_id_to_request_ids.clear()
 
     def condense(self, empty_req_indices: List[int]) -> None:
         if self.num_reqs == 0:
@@ -318,6 +353,9 @@ class InputBatch:
             if generator is not None:
                 self.generators[empty_index] = generator
 
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
 
@@ -401,6 +439,29 @@ class InputBatch:
         return prompt_token_ids_cpu_tensor.to(device=self.device,
                                               non_blocking=True)
 
+    def make_lora_inputs(
+        self, num_scheduled_tokens: np.ndarray
+    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
+        """
+        Given the num_scheduled_tokens for each request in the batch, return
+        datastructures used to activate the current LoRAs.
+        Returns:
+            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
+               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
+            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
+               where, token_lora_mapping[i] is the LoRA id to use for ith token.
+            3. lora_requests: Set of relevant LoRA requests.
+        """
+
+        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
+        prompt_lora_mapping = tuple(req_lora_mapping)
+        token_lora_mapping = tuple(
+            req_lora_mapping.repeat(num_scheduled_tokens))
+        active_lora_requests: Set[LoRARequest] = set(
+            self.lora_id_to_lora_request.values())
+
+        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
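To illustrate what make_lora_inputs returns, here is the np.repeat expansion with hypothetical ids and token counts (not taken from the change): per-request LoRA ids [1, 0, 2] and num_scheduled_tokens [3, 2, 1] yield one id per prompt and one id per scheduled token.

# Sketch of the expansion only; the ids and token counts are made up.
import numpy as np

req_lora_mapping = np.array([1, 0, 2], dtype=np.int32)
num_scheduled_tokens = np.array([3, 2, 1], dtype=np.int32)

prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

assert list(prompt_lora_mapping) == [1, 0, 2]            # one id per request
assert list(token_lora_mapping) == [1, 1, 1, 0, 0, 2]    # one id per token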
vllm/v1/worker/gpu_model_runner.py
@@ -33,6 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
@@ -40,7 +41,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 
-class GPUModelRunner:
+class GPUModelRunner(LoRAModelRunnerMixin):
 
     def __init__(
         self,
@@ -279,6 +280,7 @@
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
+                lora_request=new_req_data.lora_request,
             )
 
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -372,15 +374,16 @@
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = []
+        num_scheduled_tokens_list: List[int] = []
         max_num_scheduled_tokens = 0
         for req_id in self.input_batch.req_ids[:num_reqs]:
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens.append(num_tokens)
+            num_scheduled_tokens_list.append(num_tokens)
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
-        num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
+        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
+                                                    dtype=np.int32)
         assert max_num_scheduled_tokens > 0
 
         # Get request indices.
@@ -565,6 +568,11 @@
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
         )
+
+        # Hot-Swap lora model
+        if self.lora_config:
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
         # NOTE(woosuk): Due to chunked prefills, the batch may contain partial
         # requests. While we should not sample any token from these partial
         # requests, we do so for simplicity. We will ignore the sampled
@@ -867,6 +875,12 @@
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            if self.lora_config:
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)
 
         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
@@ -1005,14 +1019,32 @@
         # Cache the dummy encoder outputs.
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
-        # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
-        logits = self.model.compute_logits(hidden_states, None)
-        logits = logits[:self.max_num_tokens]
-        # TODO(woosuk): Consider the memory usage of the sampler.
-        torch.cuda.synchronize()
-        del hidden_states, logits
-        self.encoder_cache.clear()
+        # For profile, have maximum num_reqs and that collectively have
+        # maximum num_tokens.
+        num_reqs = self.scheduler_config.max_num_seqs
+        num_tokens = self.max_num_tokens
+        min_tokens_per_req: int = num_tokens // num_reqs
+
+        num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+
+        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
+                                                    dtype=np.int32)
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+
+        with self.maybe_profile_with_lora(self.lora_config,
+                                          num_scheduled_tokens):
+            # Trigger compilation for general shape.
+            hidden_states = self._dummy_run(self.max_num_tokens,
+                                            dummy_kv_caches)
+            hidden_states = hidden_states[logit_indices]
+            logits = self.model.compute_logits(hidden_states, None)
+            # TODO(woosuk): Consider the memory usage of the sampler.
+            torch.cuda.synchronize()
+            del hidden_states, logits
+            self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None:
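The new profiling path spreads max_num_tokens as evenly as possible over max_num_seqs requests, folds the remainder into the last request, and takes each request's last token position as its logit index. A worked example with small, hypothetical sizes:

# Hypothetical sizes; num_reqs and num_tokens are made up for illustration.
import numpy as np

num_reqs = 3                                   # stands in for max_num_seqs
num_tokens = 10                                # stands in for max_num_tokens
min_tokens_per_req = num_tokens // num_reqs    # 3

num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert num_scheduled_tokens_list == [3, 3, 4]
assert sum(num_scheduled_tokens_list) == num_tokens

num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
assert list(logit_indices) == [2, 5, 9]        # last token of each request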
vllm/v1/worker/lora_model_runner_mixin.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Define LoRA functionality mixin for model runners.
+"""
+
+from contextlib import contextmanager
+from typing import Set, Tuple
+
+import numpy as np
+import torch.nn as nn
+
+from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+logger = init_logger(__name__)
+
+
+# Defined as a mixin for GPUModelRunner
+class LoRAModelRunnerMixin:
+
+    LORA_WARMUP_RANK = 8
+
+    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
+                        scheduler_config: SchedulerConfig,
+                        lora_config: LoRAConfig, device: str) -> nn.Module:
+
+        assert supports_lora(
+            model), f"{model.__class__.__name__} does not support LoRA yet."
+
+        if supports_multimodal(model):
+            logger.warning("Regarding multimodal models, vLLM currently "
+                           "only supports adding LoRA to language model.")
+
+        # It's necessary to distinguish between the max_position_embeddings
+        # of VLMs and LLMs.
+        if hasattr(model.config, "max_position_embeddings"):
+            max_pos_embeddings = model.config.max_position_embeddings
+        else:
+            max_pos_embeddings = (
+                model.config.text_config.max_position_embeddings)
+
+        # Add LoRA Manager to the Model Runner
+        self.lora_manager = LRUCacheWorkerLoRAManager(
+            scheduler_config.max_num_seqs,
+            scheduler_config.max_num_batched_tokens,
+            model_config.get_vocab_size(),
+            lora_config,
+            device,
+            model.embedding_modules,
+            model.embedding_padding_modules,
+            max_position_embeddings=max_pos_embeddings,
+        )
+        return self.lora_manager.create_lora_manager(model)
+
+    def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
+                          token_lora_mapping: Tuple[int, ...],
+                          lora_requests: Set[LoRARequest]) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+
+        # We dont make any distinction between prefills and decodes in the
+        # scheduler. To that effect, set is_prefill to True so we use the
+        # sgmv punica kernels always.
+        lora_mapping = LoRAMapping(token_lora_mapping,
+                                   prompt_lora_mapping,
+                                   is_prefill=True)
+        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
+
+    def set_active_loras(self, input_batch: InputBatch,
+                         num_scheduled_tokens: np.ndarray) -> None:
+
+        prompt_lora_mapping: Tuple[int, ...]  # of size input_batch.num_reqs
+        token_lora_mapping: Tuple[int,
+                                  ...]  # of size np.sum(num_scheduled_tokens)
+        lora_requests: Set[LoRARequest]
+        prompt_lora_mapping, token_lora_mapping, lora_requests = \
+            input_batch.make_lora_inputs(num_scheduled_tokens)
+        return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                      lora_requests)
+
+    @contextmanager
+    def maybe_profile_with_lora(self, lora_config: LoRAConfig,
+                                num_scheduled_tokens: np.ndarray):
+        if lora_config is None:
+            yield
+        else:
+            # __enter__ code
+            assert self.lora_manager is not None, "LoRA is not enabled"
+
+            num_reqs = len(num_scheduled_tokens)
+            num_loras = lora_config.max_loras
+
+            # Make prompt lora mapping
+            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
+            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
+                                   num_loras) + 1
+
+            # Make token lora mapping
+            token_lora_mapping = np.repeat(prompt_lora_mapping,
+                                           num_scheduled_tokens)
+
+            # Make dummy lora requests
+            lora_requests: Set[LoRARequest] = {
+                LoRARequest(lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_path="/not/a/real/path")
+                for lora_id in range(1, num_loras + 1)
+            }
+
+            with self.lora_manager.dummy_lora_cache():
+                # Add the dummy LoRAs here so _set_active_loras doesn't try to
+                # load from disk.
+                for lr in lora_requests:
+                    self.lora_manager.add_dummy_lora(
+                        lr, rank=self.LORA_WARMUP_RANK)
+
+                self._set_active_loras(tuple(prompt_lora_mapping),
+                                       tuple(token_lora_mapping),
+                                       lora_requests)
+
+                yield
+
+            # __exit__ code
+            self.lora_manager.remove_all_adapters()
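To make the worst-case warm-up concrete, the cyclic assignment above can be evaluated with small, hypothetical sizes (eight requests, max_loras of 3, two scheduled tokens per request):

# Hypothetical sizes; this only re-runs the two numpy expressions above.
import numpy as np

num_reqs = 8
num_loras = 3                                       # stands in for max_loras
num_scheduled_tokens = np.full(num_reqs, 2, dtype=np.int32)

prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
assert list(prompt_lora_mapping) == [1, 2, 3, 1, 2, 3, 1, 2]

token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
assert list(token_lora_mapping) == [1, 1, 2, 2, 3, 3, 1, 1,
                                    2, 2, 3, 3, 1, 1, 2, 2]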