[Model Runner V2] Add LoRAState to consolidate lora logic (#33062)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
47
vllm/v1/worker/gpu/lora_utils.py
Normal file
47
vllm/v1/worker/gpu/lora_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
NO_LORA_ID = 0
|
||||
|
||||
|
||||
class LoraState:
|
||||
def __init__(self, max_num_reqs: int):
|
||||
self.lora_ids = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.lora_ids.fill(NO_LORA_ID)
|
||||
# req_id -> lora_request
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
req_index: int,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
if lora_request is not None:
|
||||
self.lora_requests[req_id] = lora_request
|
||||
self.lora_ids[req_index] = lora_request.lora_int_id
|
||||
else:
|
||||
self.lora_ids[req_index] = NO_LORA_ID
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.lora_requests.pop(req_id, None)
|
||||
|
||||
def make_lora_inputs(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
idx_mapping: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
lora_ids = self.lora_ids[idx_mapping]
|
||||
prompt_lora_mapping = tuple(lora_ids)
|
||||
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
|
||||
|
||||
active_lora_requests: set[LoRARequest] = set()
|
||||
for req_id in req_ids:
|
||||
lora_request = self.lora_requests.get(req_id, None)
|
||||
if lora_request is not None:
|
||||
active_lora_requests.add(lora_request)
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
@@ -51,6 +51,7 @@ from vllm.v1.worker.gpu.kv_connector import (
|
||||
KVConnector,
|
||||
get_kv_connector,
|
||||
)
|
||||
from vllm.v1.worker.gpu.lora_utils import LoraState
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
@@ -168,6 +169,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
# LoRA-related workers.
|
||||
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
|
||||
|
||||
# Buffers for CPU-to-GPU copies.
|
||||
self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32)
|
||||
@@ -426,6 +429,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.remove_request(req_id)
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
self.lora_state.remove_request(req_id)
|
||||
|
||||
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
@@ -444,7 +448,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
prompt_len=prompt_len,
|
||||
prefill_token_ids=new_req_data.prefill_token_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
@@ -469,6 +472,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.prompt_logprobs_worker.add_request(
|
||||
req_id, req_index, new_req_data.sampling_params
|
||||
)
|
||||
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
|
||||
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.apply_staged_writes()
|
||||
@@ -841,7 +845,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
lora_inputs = self.req_states.make_lora_inputs(
|
||||
lora_inputs = self.lora_state.make_lora_inputs(
|
||||
input_batch.req_ids,
|
||||
input_batch.idx_mapping_np,
|
||||
input_batch.num_scheduled_tokens,
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
NO_LORA_ID = 0
|
||||
|
||||
|
||||
class RequestState:
|
||||
def __init__(
|
||||
@@ -31,7 +26,6 @@ class RequestState:
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_reqs))
|
||||
self.extra_data: dict[str, ExtraData] = {}
|
||||
|
||||
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
@@ -70,10 +64,6 @@ class RequestState:
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# LoRA.
|
||||
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.lora_ids.fill(NO_LORA_ID)
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
@@ -84,13 +74,11 @@ class RequestState:
|
||||
prompt_len: int,
|
||||
prefill_token_ids: list[int],
|
||||
num_computed_tokens: int,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
assert len(self.free_indices) > 0, "No free indices"
|
||||
req_idx = self.free_indices.pop()
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
self.extra_data[req_id] = ExtraData(lora_request)
|
||||
|
||||
self.prompt_len[req_idx] = prompt_len
|
||||
prefill_len = len(prefill_token_ids)
|
||||
@@ -102,43 +90,15 @@ class RequestState:
|
||||
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
|
||||
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
|
||||
|
||||
if lora_request is not None:
|
||||
self.lora_ids[req_idx] = lora_request.lora_int_id
|
||||
else:
|
||||
self.lora_ids[req_idx] = NO_LORA_ID
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.prefill_len.copy_to_uva()
|
||||
self.prefill_token_ids.apply_write()
|
||||
self.num_computed_tokens.apply_write()
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.extra_data.pop(req_id, None)
|
||||
req_idx = self.req_id_to_index.pop(req_id, None)
|
||||
if req_idx is None:
|
||||
# Request not found.
|
||||
return
|
||||
self.index_to_req_id.pop(req_idx, None)
|
||||
self.free_indices.append(req_idx)
|
||||
|
||||
def make_lora_inputs(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
idx_mapping: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
lora_ids = self.lora_ids[idx_mapping]
|
||||
prompt_lora_mapping = tuple(lora_ids)
|
||||
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
|
||||
|
||||
active_lora_requests: set[LoRARequest] = set()
|
||||
for req_id in req_ids:
|
||||
lora_request = self.extra_data[req_id].lora_request
|
||||
if lora_request is not None:
|
||||
active_lora_requests.add(lora_request)
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraData:
|
||||
lora_request: LoRARequest | None
|
||||
|
||||
Reference in New Issue
Block a user