From a9b53dd435bb82f311f340ebebc15e62b9624a9d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 25 Jan 2026 22:21:12 -0800 Subject: [PATCH] [Model Runner V2] Add LoRAState to consolidate lora logic (#33062) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/lora_utils.py | 47 ++++++++++++++++++++++++++++++ vllm/v1/worker/gpu/model_runner.py | 8 +++-- vllm/v1/worker/gpu/states.py | 40 ------------------------- 3 files changed, 53 insertions(+), 42 deletions(-) create mode 100644 vllm/v1/worker/gpu/lora_utils.py diff --git a/vllm/v1/worker/gpu/lora_utils.py b/vllm/v1/worker/gpu/lora_utils.py new file mode 100644 index 000000000..146e51e8b --- /dev/null +++ b/vllm/v1/worker/gpu/lora_utils.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np + +from vllm.lora.request import LoRARequest + +NO_LORA_ID = 0 + + +class LoraState: + def __init__(self, max_num_reqs: int): + self.lora_ids = np.zeros(max_num_reqs, dtype=np.int32) + self.lora_ids.fill(NO_LORA_ID) + # req_id -> lora_request + self.lora_requests: dict[str, LoRARequest] = {} + + def add_request( + self, + req_id: str, + req_index: int, + lora_request: LoRARequest | None, + ) -> None: + if lora_request is not None: + self.lora_requests[req_id] = lora_request + self.lora_ids[req_index] = lora_request.lora_int_id + else: + self.lora_ids[req_index] = NO_LORA_ID + + def remove_request(self, req_id: str) -> None: + self.lora_requests.pop(req_id, None) + + def make_lora_inputs( + self, + req_ids: list[str], + idx_mapping: np.ndarray, + num_scheduled_tokens: np.ndarray, + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: + lora_ids = self.lora_ids[idx_mapping] + prompt_lora_mapping = tuple(lora_ids) + token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens)) + + active_lora_requests: set[LoRARequest] = set() + for req_id in req_ids: + lora_request = self.lora_requests.get(req_id, None) + if lora_request is not None: + active_lora_requests.add(lora_request) + return prompt_lora_mapping, token_lora_mapping, active_lora_requests diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 94601d4c6..a1c80fce0 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -51,6 +51,7 @@ from vllm.v1.worker.gpu.kv_connector import ( KVConnector, get_kv_connector, ) +from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.sample.output import SamplerOutput @@ -168,6 +169,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1), vocab_size=self.vocab_size, ) + # LoRA-related workers. + self.lora_state = LoraState(max_num_reqs=self.max_num_reqs) # Buffers for CPU-to-GPU copies. self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32) @@ -426,6 +429,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.supports_mm_inputs: self.encoder_runner.remove_request(req_id) self.prompt_logprobs_worker.remove_request(req_id) + self.lora_state.remove_request(req_id) def free_states(self, scheduler_output: SchedulerOutput) -> None: if self.supports_mm_inputs: @@ -444,7 +448,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): prompt_len=prompt_len, prefill_token_ids=new_req_data.prefill_token_ids, num_computed_tokens=new_req_data.num_computed_tokens, - lora_request=new_req_data.lora_request, ) req_index = self.req_states.req_id_to_index[req_id] @@ -469,6 +472,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.prompt_logprobs_worker.add_request( req_id, req_index, new_req_data.sampling_params ) + self.lora_state.add_request(req_id, req_index, new_req_data.lora_request) if scheduler_output.scheduled_new_reqs: self.req_states.apply_staged_writes() @@ -841,7 +845,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) if self.lora_config: # Activate LoRA adapters. - lora_inputs = self.req_states.make_lora_inputs( + lora_inputs = self.lora_state.make_lora_inputs( input_batch.req_ids, input_batch.idx_mapping_np, input_batch.num_scheduled_tokens, diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index b73d41a78..5379aae72 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -1,15 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass - import numpy as np import torch -from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor -NO_LORA_ID = 0 - class RequestState: def __init__( @@ -31,7 +26,6 @@ class RequestState: self.req_id_to_index: dict[str, int] = {} self.index_to_req_id: dict[int, str] = {} self.free_indices = list(range(max_num_reqs)) - self.extra_data: dict[str, ExtraData] = {} self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32) # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) @@ -70,10 +64,6 @@ class RequestState: self.max_num_reqs, dtype=torch.int32, device=device ) - # LoRA. - self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) - self.lora_ids.fill(NO_LORA_ID) - @property def num_reqs(self) -> int: return len(self.req_id_to_index) @@ -84,13 +74,11 @@ class RequestState: prompt_len: int, prefill_token_ids: list[int], num_computed_tokens: int, - lora_request: LoRARequest | None, ) -> None: assert len(self.free_indices) > 0, "No free indices" req_idx = self.free_indices.pop() self.req_id_to_index[req_id] = req_idx self.index_to_req_id[req_idx] = req_id - self.extra_data[req_id] = ExtraData(lora_request) self.prompt_len[req_idx] = prompt_len prefill_len = len(prefill_token_ids) @@ -102,43 +90,15 @@ class RequestState: self.num_computed_prefill_tokens[req_idx] = num_computed_tokens self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens) - if lora_request is not None: - self.lora_ids[req_idx] = lora_request.lora_int_id - else: - self.lora_ids[req_idx] = NO_LORA_ID - def apply_staged_writes(self) -> None: self.prefill_len.copy_to_uva() self.prefill_token_ids.apply_write() self.num_computed_tokens.apply_write() def remove_request(self, req_id: str) -> None: - self.extra_data.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None) if req_idx is None: # Request not found. return self.index_to_req_id.pop(req_idx, None) self.free_indices.append(req_idx) - - def make_lora_inputs( - self, - req_ids: list[str], - idx_mapping: np.ndarray, - num_scheduled_tokens: np.ndarray, - ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: - lora_ids = self.lora_ids[idx_mapping] - prompt_lora_mapping = tuple(lora_ids) - token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens)) - - active_lora_requests: set[LoRARequest] = set() - for req_id in req_ids: - lora_request = self.extra_data[req_id].lora_request - if lora_request is not None: - active_lora_requests.add(lora_request) - return prompt_lora_mapping, token_lora_mapping, active_lora_requests - - -@dataclass -class ExtraData: - lora_request: LoRARequest | None