[Model Runner V2] Add LoRAState to consolidate lora logic (#33062)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-01-25 22:21:12 -08:00
committed by GitHub
parent 254db42ede
commit a9b53dd435
3 changed files with 53 additions and 42 deletions

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.lora.request import LoRARequest
NO_LORA_ID = 0
class LoraState:
def __init__(self, max_num_reqs: int):
self.lora_ids = np.zeros(max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# req_id -> lora_request
self.lora_requests: dict[str, LoRARequest] = {}
def add_request(
self,
req_id: str,
req_index: int,
lora_request: LoRARequest | None,
) -> None:
if lora_request is not None:
self.lora_requests[req_id] = lora_request
self.lora_ids[req_index] = lora_request.lora_int_id
else:
self.lora_ids[req_index] = NO_LORA_ID
def remove_request(self, req_id: str) -> None:
self.lora_requests.pop(req_id, None)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.lora_requests.get(req_id, None)
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests

View File

@@ -51,6 +51,7 @@ from vllm.v1.worker.gpu.kv_connector import (
KVConnector,
get_kv_connector,
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.sample.output import SamplerOutput
@@ -168,6 +169,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
)
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# Buffers for CPU-to-GPU copies.
self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32)
@@ -426,6 +429,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.supports_mm_inputs:
@@ -444,7 +448,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prompt_len=prompt_len,
prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
lora_request=new_req_data.lora_request,
)
req_index = self.req_states.req_id_to_index[req_id]
@@ -469,6 +472,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
@@ -841,7 +845,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
lora_inputs = self.lora_state.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,

View File

@@ -1,15 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
NO_LORA_ID = 0
class RequestState:
def __init__(
@@ -31,7 +26,6 @@ class RequestState:
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
@@ -70,10 +64,6 @@ class RequestState:
self.max_num_reqs, dtype=torch.int32, device=device
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@@ -84,13 +74,11 @@ class RequestState:
prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int,
lora_request: LoRARequest | None,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
self.prompt_len[req_idx] = prompt_len
prefill_len = len(prefill_token_ids)
@@ -102,43 +90,15 @@ class RequestState:
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
def apply_staged_writes(self) -> None:
self.prefill_len.copy_to_uva()
self.prefill_token_ids.apply_write()
self.num_computed_tokens.apply_write()
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
@dataclass
class ExtraData:
lora_request: LoRARequest | None