[V1] LoRA Support (#10957)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath
2025-02-06 23:02:51 +05:30
committed by GitHub
parent 8108ac841d
commit 467a96a541
16 changed files with 453 additions and 56 deletions

View File

@@ -3,11 +3,12 @@
# Data structures defining an input batch
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
@@ -35,6 +36,8 @@ class CachedRequestState:
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
@@ -161,6 +164,12 @@ class InputBatch:
]
self.prompt_token_ids: Optional[torch.Tensor] = None
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
self.lora_id_to_lora_request: Dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
@@ -235,6 +244,19 @@ class InputBatch:
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
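As an aside, here is a standalone sketch of the bookkeeping above, using hypothetical request IDs and LoRA IDs (not part of this commit): two requests share LoRA 1 and one request has no LoRA.

from typing import Dict, Set
import numpy as np

request_lora_mapping = np.zeros((4, ), dtype=np.int32)
lora_id_to_request_ids: Dict[int, Set[str]] = {}
for req_index, (req_id, lora_id) in enumerate([("a", 1), ("b", 1), ("c", 0)]):
    if lora_id:  # 0 means the request has no LoRA
        lora_id_to_request_ids.setdefault(lora_id, set()).add(req_id)
        request_lora_mapping[req_index] = lora_id
print(request_lora_mapping[:3])  # [1 1 0]
print(lora_id_to_request_ids)    # {1: {'a', 'b'}}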
def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
@@ -251,6 +273,16 @@ class InputBatch:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
return req_index
def clear(self) -> None:
@@ -266,6 +298,9 @@ class InputBatch:
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
@@ -318,6 +353,9 @@ class InputBatch:
if generator is not None:
self.generators[empty_index] = generator
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
@@ -401,6 +439,29 @@ class InputBatch:
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
data structures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where
prompt_lora_mapping[i] is the LoRA ID to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where token_lora_mapping[i] is the LoRA ID to use for the ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: Set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
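For intuition, a minimal sketch of the repeat logic above, using hypothetical per-request values (0 denotes a request without a LoRA):

import numpy as np

req_lora_mapping = np.array([1, 0, 2], dtype=np.int32)      # one ID per request
num_scheduled_tokens = np.array([3, 2, 1], dtype=np.int32)  # tokens per request
prompt_lora_mapping = tuple(req_lora_mapping)
# (1, 0, 2): the LoRA ID for each prompt
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
# (1, 1, 1, 0, 0, 2): each scheduled token inherits its request's LoRA ID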
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)

View File

@@ -33,6 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
@@ -40,7 +41,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class GPUModelRunner:
class GPUModelRunner(LoRAModelRunnerMixin):
def __init__(
self,
@@ -279,6 +280,7 @@ class GPUModelRunner:
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
@@ -372,15 +374,16 @@ class GPUModelRunner:
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
num_scheduled_tokens_list: List[int] = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
num_scheduled_tokens_list.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
assert max_num_scheduled_tokens > 0
# Get request indices.
@@ -565,6 +568,11 @@ class GPUModelRunner:
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
# Hot-swap LoRA model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
@@ -867,6 +875,12 @@ class GPUModelRunner:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
self.scheduler_config,
self.lora_config,
self.device)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
@@ -1005,14 +1019,32 @@ class GPUModelRunner:
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()
# For the profile run, create the maximum number of requests (num_reqs) that
# collectively schedule the maximum number of tokens (num_tokens).
num_reqs = self.scheduler_config.max_num_seqs
num_tokens = self.max_num_tokens
min_tokens_per_req: int = num_tokens // num_reqs
num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens,
dummy_kv_caches)
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()
gc.collect()
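As a worked example of the even token split above, with hypothetical sizes (max_num_seqs = 4, max_num_tokens = 10) rather than real configuration values:

import numpy as np

num_reqs, num_tokens = 4, 10                        # hypothetical sizes
min_tokens_per_req = num_tokens // num_reqs         # 2
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
# [2, 2, 2, 4]: sums to num_tokens
logit_indices = np.cumsum(num_scheduled_tokens) - 1
# [1, 3, 5, 9]: each request samples from its last scheduled token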
def capture_model(self) -> None:

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""
Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import Set, Tuple
import numpy as np
import torch.nn as nn
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
LORA_WARMUP_RANK = 8
def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: LoRAConfig, device: str) -> nn.Module:
assert supports_lora(
model), f"{model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings
# of VLMs and LLMs.
if hasattr(model.config, "max_position_embeddings"):
max_pos_embeddings = model.config.max_position_embeddings
else:
max_pos_embeddings = (
model.config.text_config.max_position_embeddings)
# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
scheduler_config.max_num_seqs,
scheduler_config.max_num_batched_tokens,
model_config.get_vocab_size(),
lora_config,
device,
model.embedding_modules,
model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
return self.lora_manager.create_lora_manager(model)
def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
token_lora_mapping: Tuple[int, ...],
lora_requests: Set[LoRARequest]) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
# We don't make any distinction between prefills and decodes in the
# scheduler, so set is_prefill to True to always use the SGMV punica
# kernels.
lora_mapping = LoRAMapping(token_lora_mapping,
prompt_lora_mapping,
is_prefill=True)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
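For illustration, a minimal sketch (not part of this file) of the LoRAMapping handed to the manager, reusing the hypothetical three-request batch from the make_lora_inputs example:

from vllm.lora.layers import LoRAMapping

prompt_lora_mapping = (1, 0, 2)           # one LoRA ID per request
token_lora_mapping = (1, 1, 1, 0, 0, 2)   # one LoRA ID per scheduled token
# is_prefill=True keeps the SGMV punica kernels in use for every step,
# since the V1 scheduler does not separate prefills from decodes.
lora_mapping = LoRAMapping(token_lora_mapping, prompt_lora_mapping,
                           is_prefill=True)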
def set_active_loras(self, input_batch: InputBatch,
num_scheduled_tokens: np.ndarray) -> None:
prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: Tuple[int,
...] # of size np.sum(num_scheduled_tokens)
lora_requests: Set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = \
input_batch.make_lora_inputs(num_scheduled_tokens)
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
lora_requests)
@contextmanager
def maybe_profile_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
num_loras) + 1
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping,
num_scheduled_tokens)
# Make dummy lora requests
lora_requests: Set[LoRARequest] = {
LoRARequest(lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path")
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(
lr, rank=self.LORA_WARMUP_RANK)
self._set_active_loras(tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests)
yield
# __exit__ code
self.lora_manager.remove_all_adapters()
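To make the cyclic warm-up assignment concrete, a small worked example with hypothetical numbers (a profile batch of 5 requests and max_loras = 2):

import numpy as np

num_scheduled_tokens = np.array([3, 1, 2, 1, 1], dtype=np.int32)
num_reqs, num_loras = len(num_scheduled_tokens), 2
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
# [1, 2, 1, 2, 1]: LoRA IDs are assigned cyclically so every slot is exercised
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# [1, 1, 1, 2, 1, 1, 2, 1]: one entry per dummy token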