Reduce LoRA kernel overhead when the number of active LoRAs is smaller than max_loras. Separate CUDA graphs are captured for different numbers of active LoRAs. (#32005)

Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
yugong333
2026-02-02 09:30:06 -08:00
committed by GitHub
parent 8b7346d5f1
commit ffe1fc7a28
15 changed files with 323 additions and 66 deletions

View File

@@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel(
expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1)
# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1
fused_moe_lora(
output,
hidden_states,
@@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel(
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
@@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive(
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1
fused_moe_lora(
output,
hidden_states,
@@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive(
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],

View File

@@ -161,7 +161,7 @@ def check_lora_shrink_kernel(
data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
*lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
scaling,
)
@@ -234,7 +234,7 @@ def check_lora_expand_kernel(
data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
*lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
offset_start=0,
add_inputs=add_inputs,
)

View File

@@ -17,6 +17,7 @@ from vllm.config import (
SchedulerConfig,
VllmConfig,
)
from vllm.config.lora import LoRAConfig
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
@@ -47,6 +48,12 @@ def _create_vllm_config(
mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
else:
# Create a real LoRAConfig with specialize_active_lora enabled
mock_config.lora_config = LoRAConfig(
max_loras=4,
specialize_active_lora=True,
)
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1(
@@ -106,15 +113,19 @@ class TestCudagraphDispatcher:
)
# Verify the key is initialized correctly
# With LoRA specialization (max_loras=4, specialize_active_lora=True):
# - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
# - capture_sizes = [1, 8]
# - Total keys = 2 sizes × 5 lora_cases = 10
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
4 if lora_config else 2
10 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
4 if lora_config else 2
10 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

View File

@@ -60,6 +60,13 @@ class LoRAConfig:
of multimodal models will be enabled. This is an experimental feature and
currently only supports some MM models such as the Qwen VL series. The default
is False."""
specialize_active_lora: bool = False
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
When set to True, separate cuda graphs will be captured for different counts
of active LoRAs (powers of 2 up to max_loras), which can improve performance
for variable LoRA usage patterns at the cost of increased startup time and
memory usage. Only takes effect when cudagraph_specialize_lora is True.
"""
def compute_hash(self) -> str:
"""

View File

@@ -485,6 +485,7 @@ class EngineArgs:
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
@@ -1026,6 +1027,9 @@ class EngineArgs:
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
)
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
)
# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
@@ -1657,6 +1661,7 @@ class EngineArgs:
fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype,
enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0
else None,

View File

@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When LoRAConfig.specialize_active_lora is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
@@ -54,7 +62,11 @@ class BatchDescriptor(NamedTuple):
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)

View File

@@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
def _adjust_kernel_inputs(
max_loras: int,
num_active_loras: int,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
@@ -109,7 +109,7 @@ def _adjust_kernel_inputs(
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
grid_lora_dim = max_loras + 1
grid_lora_dim = num_active_loras
return grid_lora_dim, stride_tl, stride_el
@@ -354,6 +354,7 @@ def _fused_moe_lora_shrink(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
@@ -373,7 +374,7 @@ def _fused_moe_lora_shrink(
b_ptr = _get_ptr(lora_a_stacked, device)
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids
num_active_loras, sorted_token_ids, expert_ids
)
grid = lambda META: (
split_k
@@ -457,6 +458,7 @@ def _fused_moe_lora_expand(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
@@ -484,7 +486,7 @@ def _fused_moe_lora_expand(
}
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids
num_active_loras, sorted_token_ids, expert_ids
)
grid = lambda META: (
@@ -557,6 +559,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@@ -648,6 +651,7 @@ def _fused_moe_lora(
shrink_num_warps,
shrink_num_stages,
shrink_split_k,
num_active_loras,
mul_routed_weight,
use_gdc=use_gdc,
)
@@ -695,6 +699,7 @@ def _fused_moe_lora(
expand_num_warps,
expand_num_stages,
expand_split_k,
num_active_loras,
mul_routed_weight,
offset,
use_gdc=use_gdc,
@@ -714,6 +719,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@@ -730,6 +736,8 @@ def _fused_moe_lora_fake(
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
fully_sharded: bool = False,
offset: int = 0,
) -> None:
return
@@ -761,6 +769,7 @@ def _fused_moe_lora_shrink_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
@@ -770,6 +779,7 @@ def _fused_moe_lora_shrink_fake(
def _fused_moe_lora_expand_fake(
output: torch.Tensor,
a_intermediate_cache1: torch.Tensor,
b_intermediate_cache1: torch.Tensor,
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
@@ -796,7 +806,9 @@ def _fused_moe_lora_expand_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
) -> None:
return

View File

@@ -138,6 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
@@ -234,10 +235,7 @@ def _lora_expand(
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
num_active_loras,
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
@@ -291,6 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:

View File

@@ -4,7 +4,8 @@
LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
import bisect
from dataclasses import dataclass, field
import torch
@@ -28,9 +29,22 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
# to the nearest value in this list to match cudagraph capture keys.
# Empty list means no specialization (use actual count).
captured_lora_counts: list[int] = field(default_factory=list)
@staticmethod
def make(
max_loras: int, max_num_tokens: int, device: torch.device | str
max_loras: int,
max_num_tokens: int,
device: torch.device | str,
captured_lora_counts: list[int] | None = None,
) -> "LoRAKernelMeta":
token_lora_mapping = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
@@ -66,6 +80,9 @@ class LoRAKernelMeta:
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
)
def _reset(self):
@@ -73,6 +90,8 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0
self.captured_lora_counts = []
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
@@ -118,6 +137,15 @@ class LoRAKernelMeta:
num_tokens_per_lora, non_blocking=True
)
self.num_active_loras = lora_ids.size(0)
# Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph.
if self.captured_lora_counts and self.num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx]
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
@@ -125,7 +153,9 @@ class LoRAKernelMeta:
)
def meta_args(
self, token_nums: int
self,
token_nums: int,
specialize_active_lora: bool,
) -> tuple[
torch.Tensor,
torch.Tensor,
@@ -133,6 +163,7 @@ class LoRAKernelMeta:
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
]:
"""
This function returns the kernel metadata required for the current
@@ -144,6 +175,7 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward
pass of the kernel.
"""
max_loras = self.active_lora_ids.size(0) - 1
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
@@ -151,4 +183,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1,
)

View File

@@ -134,6 +134,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
scaling: float,
) -> None:
"""
@@ -214,10 +215,7 @@ def _lora_shrink(
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
num_active_loras,
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
@@ -269,6 +267,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
scaling: float,
) -> None:
return

View File

@@ -12,6 +12,7 @@ from typing import final
import torch
from vllm.lora.layers import LoRAMapping
from vllm.lora.utils import get_captured_lora_counts
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import round_up
@@ -48,8 +49,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.lora_config = kwargs["lora_config"]
self.max_loras = self.lora_config.max_loras
# Compute captured LoRA counts for cudagraph specialization.
captured_lora_counts = get_captured_lora_counts(
self.max_loras, self.lora_config.specialize_active_lora
)
self.token_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device
self.max_loras,
max_num_batched_tokens,
device=device,
captured_lora_counts=captured_lora_counts,
)
# When speculative decoding is enabled, max_num_samples is
@@ -57,7 +66,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
# This line can be optimized by replacing max_num_batched_tokens
# to max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device
self.max_loras,
max_num_batched_tokens,
device=device,
captured_lora_counts=captured_lora_counts,
)
def update_metadata(
@@ -102,7 +114,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x,
lora_a_stacked,
y,
*self.token_mapping_meta.meta_args(x.size(0)),
*self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
scale,
)
@@ -143,7 +157,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x,
lora_b_stacked,
y,
*self.token_mapping_meta.meta_args(num_tokens),
*self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
),
offset_start=offset_start,
add_inputs=True,
)
@@ -175,7 +191,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x.unsqueeze(dim=0),
(lora_b_stacked,),
y,
*self.token_mapping_meta.meta_args(x.size(0)),
*self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
offset_start=0,
add_inputs=add_inputs,
)
@@ -287,7 +305,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x,
[lora_a_stacked],
buffer.unsqueeze(dim=0),
*self.prompt_mapping_meta.meta_args(x.size(0)),
*self.prompt_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
scale,
)
@@ -295,7 +315,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
buffer.unsqueeze(dim=0),
[lora_b_stacked],
y,
*self.prompt_mapping_meta.meta_args(buffer.size(0)),
*self.prompt_mapping_meta.meta_args(
buffer.size(0), self.lora_config.specialize_active_lora
),
add_inputs=True,
)
y = y.view_as(y_org)
@@ -316,8 +338,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
)
)
if naive_block_assignment:
expert_ids = topk_ids.reshape(-1)
@@ -392,7 +416,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
_,
lora_ids,
_,
) = self.token_mapping_meta.meta_args(x.size(0))
num_active_loras,
) = self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
)
if token_lora_mapping is None:
token_lora_mapping = token_lora_mapping_meta
fused_moe_lora(
@@ -408,6 +435,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64),

View File

@@ -44,6 +44,25 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
    """
    Return the num_active_loras values that cudagraphs are captured for.

    Args:
        max_loras: Configured maximum number of LoRA adapters.
        specialize: Whether to capture per-count graphs.

    Returns:
        With ``specialize=False`` the single value ``[max_loras + 1]``.
        With ``specialize=True`` every power of two below ``max_loras + 1``
        in ascending order, followed by ``max_loras + 1`` itself.

    The result is consumed by both CudagraphDispatcher and PunicaWrapperGPU,
    so the two must agree on this list.
    """
    limit = max_loras + 1
    if not specialize:
        return [limit]

    counts: list[int] = []
    power = 1
    # Collect 1, 2, 4, ... strictly below the limit.
    while power < limit:
        counts.append(power)
        power <<= 1
    # The limit is always the last entry; a non-positive limit yields an
    # empty list, matching a scan over range(1, limit + 1).
    if limit >= 1:
        counts.append(limit)
    return counts
_GLOBAL_LORA_ID = 0

View File

@@ -5,6 +5,7 @@ from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
from vllm.lora.utils import get_captured_lora_counts
logger = init_logger(__name__)
@@ -57,6 +58,11 @@ class CudagraphDispatcher:
)
self.keys_initialized = False
self.specialize_lora_count = (
self.vllm_config.lora_config.specialize_active_lora
if self.vllm_config.lora_config is not None
else False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE
@@ -92,8 +98,33 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes."
)
def _get_lora_cases(self) -> list[int]:
    """
    Return the num_active_loras values to create CUDA graph keys for.

    An entry of 0 stands for the no-LoRA case; every positive entry is an
    active-LoRA count to capture with.

    Returns:
        - ``[0]`` when no LoRA config is present.
        - ``[max_loras + 1]`` when LoRA is configured but
          ``cudagraph_specialize_lora`` is off (only LoRA-active graphs).
        - ``[0]`` followed by the counts from ``get_captured_lora_counts``
          when ``cudagraph_specialize_lora`` is on.
    """
    lora_cfg = self.vllm_config.lora_config

    # LoRA disabled: a single capture case without LoRA.
    if lora_cfg is None:
        return [0]

    # LoRA enabled but not specialized: capture only with LoRA active.
    if not self.compilation_config.cudagraph_specialize_lora:
        return [lora_cfg.max_loras + 1]

    # Specialized: the no-LoRA case plus each captured active-LoRA count.
    lora_counts = get_captured_lora_counts(
        lora_cfg.max_loras, self.specialize_lora_count
    )
    return [0] + lora_counts
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
@@ -111,6 +142,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
num_active_loras=num_active_loras,
)
def add_cudagraph_key(
@@ -135,26 +167,23 @@ class CudagraphDispatcher:
self._compute_bs_to_padded_graph_size()
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
# Get LoRA cases to capture
lora_cases = self._get_lora_cases()
self.captured_lora_counts = [
lora_count for lora_count in lora_cases if lora_count
]
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product(
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, has_lora
bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
)
@@ -173,10 +202,14 @@ class CudagraphDispatcher:
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
for bs, num_active_loras in product(
cudagraph_capture_sizes_for_decode, lora_cases
):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora),
self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras
),
)
self.keys_initialized = True
@@ -187,6 +220,7 @@ class CudagraphDispatcher:
uniform_decode: bool = False,
has_lora: bool = False,
disable_full: bool = False,
num_active_loras: int = 0,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using piecewise only),
@@ -202,6 +236,7 @@ class CudagraphDispatcher:
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
num_active_loras: Number of distinct active LoRA adapters.
"""
if (
not self.keys_initialized
@@ -210,8 +245,24 @@ class CudagraphDispatcher:
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
effective_num_active_loras = num_active_loras
if has_lora and num_active_loras > 0:
if self.specialize_lora_count:
# Find the smallest captured `num_active_loras` that is >= the current
# `num_active_loras`. This is because we only capture graphs for
# a subset of possible `num_active_loras` values (powers of 2).
import bisect
idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
if idx < len(self.captured_lora_counts):
effective_num_active_loras = self.captured_lora_counts[idx]
else:
# When not specializing, graphs are captured only with max_loras + 1,
# so we must use max_loras + 1 for dispatch to find a matching graph.
effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

View File

@@ -3082,6 +3082,7 @@ class GPUModelRunner(
# be improved in model runner v2)
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
force_num_active_loras: int | None = None,
num_encoder_reqs: int = 0,
) -> tuple[
CUDAGraphMode,
@@ -3103,11 +3104,13 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
)
has_lora = (
len(self.input_batch.lora_id_to_lora_request) > 0
if force_has_lora is None
else force_has_lora
# Compute LoRA state for cudagraph dispatch
num_active_loras = (
force_num_active_loras
if force_num_active_loras is not None
else len(self.input_batch.lora_id_to_lora_request)
)
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
dispatch_cudagraph = (
@@ -3116,6 +3119,7 @@ class GPUModelRunner(
has_lora=has_lora,
uniform_decode=uniform_decode,
disable_full=disable_full,
num_active_loras=num_active_loras,
)
if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
@@ -4606,8 +4610,8 @@ class GPUModelRunner(
is_profile: bool = False,
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
num_active_loras: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -4630,7 +4634,8 @@ class GPUModelRunner(
create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
activate_lora: If False, dummy_run is performed without LoRAs.
num_active_loras: Number of distinct active LoRAs to capture for.
LoRA is activated when num_active_loras > 0.
"""
mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_encoder_only:
@@ -4712,7 +4717,10 @@ class GPUModelRunner(
# `force_has_lora` is used for cudagraph capture; because LoRA is
# activated later in the context manager, but we need to know the
# LoRA state when determining the batch descriptor for capture
force_has_lora=activate_lora,
force_has_lora=num_active_loras > 0,
# `force_num_active_loras` is used for cudagraph capture; because we
# need to capture graphs for specific num_active_loras counts
force_num_active_loras=num_active_loras,
)
)
@@ -4782,8 +4790,8 @@ class GPUModelRunner(
self.lora_config,
num_scheduled_tokens,
num_sampled_tokens,
activate_lora,
remove_lora,
num_active_loras,
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_padded <= self.max_num_tokens
@@ -4884,7 +4892,10 @@ class GPUModelRunner(
# lora cases when cudagraph_specialize_lora is enabled. This is a
# short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/28334
if self.compilation_config.cudagraph_specialize_lora and activate_lora:
if (
self.compilation_config.cudagraph_specialize_lora
and num_active_loras > 0
):
use_cudagraphs = False
self.drafter.dummy_run(
@@ -5259,7 +5270,7 @@ class GPUModelRunner(
# We skip EPLB here since we don't want to record dummy metrics
for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
activate_lora = batch_desc.has_lora
num_active_loras = batch_desc.num_active_loras
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
@@ -5286,7 +5297,7 @@ class GPUModelRunner(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
allow_microbatching=allow_microbatching,
activate_lora=activate_lora,
num_active_loras=num_active_loras,
)
# Capture run
@@ -5294,7 +5305,7 @@ class GPUModelRunner(
num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
allow_microbatching=allow_microbatching,
activate_lora=activate_lora,
num_active_loras=num_active_loras,
is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)

View File

@@ -133,11 +133,23 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens: np.ndarray,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True,
num_active_loras: int = 0,
):
"""
Context manager to select dummy LoRAs for capture/warmup.
Args:
lora_config: LoRA configuration, or None if LoRA is disabled.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
num_active_loras: Number of distinct active LoRAs to use.
- 0: No LoRA active (set up zero mappings).
- >0: Use exactly this many distinct LoRAs.
"""
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
# Skip LoRA setup entirely only if no LoRA config
if lora_config is None:
yield
else:
@@ -145,15 +157,52 @@ class LoRAModelRunnerMixin:
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras
max_loras = lora_config.max_loras
# Determine how many distinct LoRAs to use and whether to include
# no-LoRA tokens (-1 entries).
# When num_active_loras > max_loras (e.g., max_loras + 1), we need
# to include -1 entries to simulate batches with both LoRA and
# no-LoRA tokens. This ensures prepare_tensors computes the correct
# num_active_loras that matches the cudagraph capture key.
if num_active_loras == 0:
# No LoRA active - use 0 mappings like the original code
effective_num_loras = 0
include_no_lora = False
elif num_active_loras > max_loras:
# num_active_loras > max_loras means we want max_loras adapters
# PLUS no-LoRA tokens (-1). This is the max_loras + 1 case.
effective_num_loras = max_loras
include_no_lora = True
else:
# Specific number of active LoRAs requested
effective_num_loras = min(num_active_loras, max_loras)
include_no_lora = False
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
if activate_lora:
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % num_loras
) + 1
# LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest.
# convert_mapping() will convert these to 0-indexed slot indices.
if effective_num_loras > 0:
if include_no_lora:
# Include -1 (no-LoRA) entries by cycling through
# -1, 1, 2, ..., effective_num_loras
# This ensures prepare_tensors sees both LoRA and no-LoRA
# tokens, computing num_active_loras = effective_num_loras+1
cycle_values = np.array(
list(range(1, effective_num_loras + 1)),
dtype=np.int32,
)
prompt_lora_mapping = cycle_values[
np.arange(num_reqs, dtype=np.int32) % len(cycle_values)
]
else:
# Use 1 to effective_num_loras (1-indexed lora IDs)
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
) + 1
else:
# No LoRA active - use 0 for all tokens (original behavior)
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
# Make sample lora mapping
@@ -162,14 +211,14 @@ class LoRAModelRunnerMixin:
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# Make dummy lora requests
# Make dummy lora requests (only for the active LoRAs)
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, num_loras + 1)
for lora_id in range(1, effective_num_loras + 1)
}
self._set_active_loras(
@@ -187,10 +236,21 @@ class LoRAModelRunnerMixin:
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
num_active_loras: int = 0,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
):
"""
Context manager for dummy runs with LoRA.
Args:
lora_config: LoRA configuration.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
remove_lora: Whether to remove LoRAs after the context exits.
num_active_loras: Number of distinct active LoRAs to use.
LoRA is activated when num_active_loras > 0.
"""
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
@@ -198,7 +258,7 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens,
mapping_type,
num_sampled_tokens,
activate_lora,
num_active_loras,
),
):
yield