Reduce the kernel overhead when the number of active LoRAs is smaller than max_loras. Separate CUDA graphs are captured for different numbers of active LoRAs. (#32005)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
@@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel(
|
||||
expert_ids = expert_ids.view(max_loras, -1)
|
||||
sorted_token_ids = sorted_token_ids.view(max_loras, -1)
|
||||
|
||||
# num_active_loras is the number of active LoRAs
|
||||
# (max_loras + 1 to include no-lora case)
|
||||
num_active_loras = max_loras + 1
|
||||
|
||||
fused_moe_lora(
|
||||
output,
|
||||
hidden_states,
|
||||
@@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel(
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
num_active_loras,
|
||||
adapter_enabled,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
@@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive(
|
||||
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
|
||||
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
|
||||
|
||||
# num_active_loras is the number of active LoRAs
|
||||
# (max_loras + 1 to include no-lora case)
|
||||
num_active_loras = max_loras + 1
|
||||
|
||||
fused_moe_lora(
|
||||
output,
|
||||
hidden_states,
|
||||
@@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive(
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
num_active_loras,
|
||||
adapter_enabled,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
|
||||
@@ -161,7 +161,7 @@ def check_lora_shrink_kernel(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
out_tensor,
|
||||
*lora_meta.meta_args(token_nums=token_nums),
|
||||
*lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
|
||||
scaling,
|
||||
)
|
||||
|
||||
@@ -234,7 +234,7 @@ def check_lora_expand_kernel(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
out_tensor,
|
||||
*lora_meta.meta_args(token_nums=token_nums),
|
||||
*lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.config import (
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
@@ -47,6 +48,12 @@ def _create_vllm_config(
|
||||
mock_config.speculative_config = None # No speculative decoding
|
||||
if not lora_config:
|
||||
mock_config.lora_config = None
|
||||
else:
|
||||
# Create a real LoRAConfig with specialize_active_lora enabled
|
||||
mock_config.lora_config = LoRAConfig(
|
||||
max_loras=4,
|
||||
specialize_active_lora=True,
|
||||
)
|
||||
# Mimic the behavior of VllmConfig.__post_init__()
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
compilation_config.set_splitting_ops_for_v1(
|
||||
@@ -106,15 +113,19 @@ class TestCudagraphDispatcher:
|
||||
)
|
||||
|
||||
# Verify the key is initialized correctly
|
||||
# With LoRA specialization (max_loras=4, specialize_active_lora=True):
|
||||
# - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
|
||||
# - capture_sizes = [1, 8]
|
||||
# - Total keys = 2 sizes × 5 lora_cases = 10
|
||||
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
|
||||
4 if lora_config else 2
|
||||
10 if lora_config else 2
|
||||
)
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
|
||||
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
|
||||
4 if lora_config else 2
|
||||
10 if lora_config else 2
|
||||
)
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
|
||||
|
||||
@@ -60,6 +60,13 @@ class LoRAConfig:
|
||||
of multimodal models will be enabled. This is an experimental feature and
|
||||
currently only supports some MM models such as the Qwen VL series. The default
|
||||
is False."""
|
||||
specialize_active_lora: bool = False
|
||||
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
|
||||
When set to True, separate cuda graphs will be captured for different counts
|
||||
of active LoRAs (powers of 2 up to max_loras, plus max_loras + 1), which can improve performance
|
||||
for variable LoRA usage patterns at the cost of increased startup time and
|
||||
memory usage. Only takes effect when cudagraph_specialize_lora is True.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -485,6 +485,7 @@ class EngineArgs:
|
||||
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
|
||||
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
|
||||
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
|
||||
specialize_active_lora: bool = LoRAConfig.specialize_active_lora
|
||||
|
||||
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
||||
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
|
||||
@@ -1026,6 +1027,9 @@ class EngineArgs:
|
||||
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
|
||||
)
|
||||
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
|
||||
lora_group.add_argument(
|
||||
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
|
||||
)
|
||||
|
||||
# Observability arguments
|
||||
observability_kwargs = get_kwargs(ObservabilityConfig)
|
||||
@@ -1657,6 +1661,7 @@ class EngineArgs:
|
||||
fully_sharded_loras=self.fully_sharded_loras,
|
||||
lora_dtype=self.lora_dtype,
|
||||
enable_tower_connector_lora=self.enable_tower_connector_lora,
|
||||
specialize_active_lora=self.specialize_active_lora,
|
||||
max_cpu_loras=self.max_cpu_loras
|
||||
if self.max_cpu_loras and self.max_cpu_loras > 0
|
||||
else None,
|
||||
|
||||
@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
|
||||
"""
|
||||
Whether this batch has active LoRA adapters.
|
||||
"""
|
||||
num_active_loras: int = 0
|
||||
"""
|
||||
Number of distinct active LoRA adapters in this batch.
|
||||
When LoRAConfig.specialize_active_lora is enabled, separate CUDA graphs
|
||||
are captured for selected num_active_loras values. This allows kernels
|
||||
(like fused_moe_lora) whose grid size depends on num_active_loras
|
||||
to be properly captured.
|
||||
"""
|
||||
|
||||
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
|
||||
"""
|
||||
@@ -54,7 +62,11 @@ class BatchDescriptor(NamedTuple):
|
||||
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
|
||||
"""
|
||||
return BatchDescriptor(
|
||||
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
|
||||
self.num_tokens,
|
||||
num_reqs=None,
|
||||
uniform=False,
|
||||
has_lora=self.has_lora,
|
||||
num_active_loras=self.num_active_loras,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
|
||||
|
||||
|
||||
def _adjust_kernel_inputs(
|
||||
max_loras: int,
|
||||
num_active_loras: int,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
):
|
||||
@@ -109,7 +109,7 @@ def _adjust_kernel_inputs(
|
||||
else:
|
||||
stride_tl = sorted_token_ids.stride(0)
|
||||
stride_el = expert_ids.stride(0)
|
||||
grid_lora_dim = max_loras + 1
|
||||
grid_lora_dim = num_active_loras
|
||||
return grid_lora_dim, stride_tl, stride_el
|
||||
|
||||
|
||||
@@ -354,6 +354,7 @@ def _fused_moe_lora_shrink(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
use_gdc: bool = False,
|
||||
) -> None:
|
||||
@@ -373,7 +374,7 @@ def _fused_moe_lora_shrink(
|
||||
b_ptr = _get_ptr(lora_a_stacked, device)
|
||||
|
||||
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
|
||||
w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids
|
||||
num_active_loras, sorted_token_ids, expert_ids
|
||||
)
|
||||
grid = lambda META: (
|
||||
split_k
|
||||
@@ -457,6 +458,7 @@ def _fused_moe_lora_expand(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
use_gdc: bool = False,
|
||||
@@ -484,7 +486,7 @@ def _fused_moe_lora_expand(
|
||||
}
|
||||
|
||||
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
|
||||
w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids
|
||||
num_active_loras, sorted_token_ids, expert_ids
|
||||
)
|
||||
|
||||
grid = lambda META: (
|
||||
@@ -557,6 +559,7 @@ def _fused_moe_lora(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@@ -648,6 +651,7 @@ def _fused_moe_lora(
|
||||
shrink_num_warps,
|
||||
shrink_num_stages,
|
||||
shrink_split_k,
|
||||
num_active_loras,
|
||||
mul_routed_weight,
|
||||
use_gdc=use_gdc,
|
||||
)
|
||||
@@ -695,6 +699,7 @@ def _fused_moe_lora(
|
||||
expand_num_warps,
|
||||
expand_num_stages,
|
||||
expand_split_k,
|
||||
num_active_loras,
|
||||
mul_routed_weight,
|
||||
offset,
|
||||
use_gdc=use_gdc,
|
||||
@@ -714,6 +719,7 @@ def _fused_moe_lora_fake(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@@ -730,6 +736,8 @@ def _fused_moe_lora_fake(
|
||||
expand_num_stages: int,
|
||||
expand_split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
fully_sharded: bool = False,
|
||||
offset: int = 0,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@@ -761,6 +769,7 @@ def _fused_moe_lora_shrink_fake(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
use_gdc: bool = False,
|
||||
) -> None:
|
||||
@@ -770,6 +779,7 @@ def _fused_moe_lora_shrink_fake(
|
||||
def _fused_moe_lora_expand_fake(
|
||||
output: torch.Tensor,
|
||||
a_intermediate_cache1: torch.Tensor,
|
||||
b_intermediate_cache1: torch.Tensor,
|
||||
lora_b_stacked: list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
@@ -796,7 +806,9 @@ def _fused_moe_lora_expand_fake(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
use_gdc: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@@ -138,6 +138,7 @@ def _lora_expand(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs; used as the LoRA grid dimension
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
@@ -234,10 +235,7 @@ def _lora_expand(
|
||||
grid = (
|
||||
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks simply exit.
|
||||
MAX_LORAS,
|
||||
num_active_loras,
|
||||
)
|
||||
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
|
||||
# making PDL invalid and affecting the kernel performance.
|
||||
@@ -291,6 +289,7 @@ def _lora_expand_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
LoRA kernels metadata preparation utilities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import bisect
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
@@ -28,9 +29,22 @@ class LoRAKernelMeta:
|
||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||
no_lora_flag_cpu: torch.Tensor
|
||||
|
||||
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
|
||||
# Stored as a Python int to avoid GPU->CPU sync during forward pass
|
||||
num_active_loras: int = 0
|
||||
|
||||
# Captured LoRA counts for cudagraph specialization (sorted list).
|
||||
# When specialize_active_lora is enabled, num_active_loras is rounded up
|
||||
# to the nearest value in this list to match cudagraph capture keys.
|
||||
# Empty list means no specialization (use actual count).
|
||||
captured_lora_counts: list[int] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
max_loras: int, max_num_tokens: int, device: torch.device | str
|
||||
max_loras: int,
|
||||
max_num_tokens: int,
|
||||
device: torch.device | str,
|
||||
captured_lora_counts: list[int] | None = None,
|
||||
) -> "LoRAKernelMeta":
|
||||
token_lora_mapping = torch.empty(
|
||||
max_num_tokens, dtype=torch.int32, device=device
|
||||
@@ -66,6 +80,9 @@ class LoRAKernelMeta:
|
||||
num_tokens_per_lora=num_tokens_per_lora,
|
||||
lora_token_start_loc=lora_token_start_loc,
|
||||
no_lora_flag_cpu=no_lora_flag_cpu,
|
||||
captured_lora_counts=sorted(captured_lora_counts)
|
||||
if captured_lora_counts
|
||||
else [],
|
||||
)
|
||||
|
||||
def _reset(self):
|
||||
@@ -73,6 +90,8 @@ class LoRAKernelMeta:
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
self.no_lora_flag_cpu.fill_(False)
|
||||
self.num_active_loras = 0
|
||||
self.captured_lora_counts = []
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
@@ -118,6 +137,15 @@ class LoRAKernelMeta:
|
||||
num_tokens_per_lora, non_blocking=True
|
||||
)
|
||||
|
||||
self.num_active_loras = lora_ids.size(0)
|
||||
|
||||
# Round up num_active_loras to match cudagraph capture keys.
|
||||
# This ensures the kernel grid dimension matches the captured graph.
|
||||
if self.captured_lora_counts and self.num_active_loras > 0:
|
||||
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
|
||||
if idx < len(self.captured_lora_counts):
|
||||
self.num_active_loras = self.captured_lora_counts[idx]
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
|
||||
@@ -125,7 +153,9 @@ class LoRAKernelMeta:
|
||||
)
|
||||
|
||||
def meta_args(
|
||||
self, token_nums: int
|
||||
self,
|
||||
token_nums: int,
|
||||
specialize_active_lora: bool,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
@@ -133,6 +163,7 @@ class LoRAKernelMeta:
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
int,
|
||||
]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
@@ -144,6 +175,7 @@ class LoRAKernelMeta:
|
||||
token_nums (int): Number of input tokens in the current forward
|
||||
pass of the kernel.
|
||||
"""
|
||||
max_loras = self.active_lora_ids.size(0) - 1
|
||||
return (
|
||||
self.token_lora_mapping[:token_nums],
|
||||
self.token_indices_sorted_by_lora_ids[:token_nums],
|
||||
@@ -151,4 +183,5 @@ class LoRAKernelMeta:
|
||||
self.lora_token_start_loc,
|
||||
self.active_lora_ids,
|
||||
self.no_lora_flag_cpu,
|
||||
self.num_active_loras if specialize_active_lora else max_loras + 1,
|
||||
)
|
||||
|
||||
@@ -134,6 +134,7 @@ def _lora_shrink(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs; used as the LoRA grid dimension
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -214,10 +215,7 @@ def _lora_shrink(
|
||||
grid = (
|
||||
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks exit early.
|
||||
MAX_LORAS,
|
||||
num_active_loras,
|
||||
)
|
||||
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
|
||||
# making PDL invalid and affecting the kernel performance.
|
||||
@@ -269,6 +267,7 @@ def _lora_shrink_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import final
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.utils import get_captured_lora_counts
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
@@ -48,8 +49,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
self.lora_config = kwargs["lora_config"]
|
||||
self.max_loras = self.lora_config.max_loras
|
||||
|
||||
# Compute captured LoRA counts for cudagraph specialization.
|
||||
captured_lora_counts = get_captured_lora_counts(
|
||||
self.max_loras, self.lora_config.specialize_active_lora
|
||||
)
|
||||
|
||||
self.token_mapping_meta = LoRAKernelMeta.make(
|
||||
self.max_loras, max_num_batched_tokens, device=device
|
||||
self.max_loras,
|
||||
max_num_batched_tokens,
|
||||
device=device,
|
||||
captured_lora_counts=captured_lora_counts,
|
||||
)
|
||||
|
||||
# When speculative decoding is enabled, max_num_samples is
|
||||
@@ -57,7 +66,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
# This line can be optimized by replacing max_num_batched_tokens
|
||||
# to max_batches * (num_speculative_decoding_tokens + 1).
|
||||
self.prompt_mapping_meta = LoRAKernelMeta.make(
|
||||
self.max_loras, max_num_batched_tokens, device=device
|
||||
self.max_loras,
|
||||
max_num_batched_tokens,
|
||||
device=device,
|
||||
captured_lora_counts=captured_lora_counts,
|
||||
)
|
||||
|
||||
def update_metadata(
|
||||
@@ -102,7 +114,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
x,
|
||||
lora_a_stacked,
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(x.size(0)),
|
||||
*self.token_mapping_meta.meta_args(
|
||||
x.size(0), self.lora_config.specialize_active_lora
|
||||
),
|
||||
scale,
|
||||
)
|
||||
|
||||
@@ -143,7 +157,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
x,
|
||||
lora_b_stacked,
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(num_tokens),
|
||||
*self.token_mapping_meta.meta_args(
|
||||
num_tokens, self.lora_config.specialize_active_lora
|
||||
),
|
||||
offset_start=offset_start,
|
||||
add_inputs=True,
|
||||
)
|
||||
@@ -175,7 +191,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
x.unsqueeze(dim=0),
|
||||
(lora_b_stacked,),
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(x.size(0)),
|
||||
*self.token_mapping_meta.meta_args(
|
||||
x.size(0), self.lora_config.specialize_active_lora
|
||||
),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
@@ -287,7 +305,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
x,
|
||||
[lora_a_stacked],
|
||||
buffer.unsqueeze(dim=0),
|
||||
*self.prompt_mapping_meta.meta_args(x.size(0)),
|
||||
*self.prompt_mapping_meta.meta_args(
|
||||
x.size(0), self.lora_config.specialize_active_lora
|
||||
),
|
||||
scale,
|
||||
)
|
||||
|
||||
@@ -295,7 +315,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
buffer.unsqueeze(dim=0),
|
||||
[lora_b_stacked],
|
||||
y,
|
||||
*self.prompt_mapping_meta.meta_args(buffer.size(0)),
|
||||
*self.prompt_mapping_meta.meta_args(
|
||||
buffer.size(0), self.lora_config.specialize_active_lora
|
||||
),
|
||||
add_inputs=True,
|
||||
)
|
||||
y = y.view_as(y_org)
|
||||
@@ -316,8 +338,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
Aligns tokens and experts into block-sized chunks for LoRA-based
|
||||
mixture-of-experts (MoE) execution.
|
||||
"""
|
||||
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
|
||||
num_tokens
|
||||
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
|
||||
self.token_mapping_meta.meta_args(
|
||||
num_tokens, self.lora_config.specialize_active_lora
|
||||
)
|
||||
)
|
||||
if naive_block_assignment:
|
||||
expert_ids = topk_ids.reshape(-1)
|
||||
@@ -392,7 +416,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
_,
|
||||
lora_ids,
|
||||
_,
|
||||
) = self.token_mapping_meta.meta_args(x.size(0))
|
||||
num_active_loras,
|
||||
) = self.token_mapping_meta.meta_args(
|
||||
x.size(0), self.lora_config.specialize_active_lora
|
||||
)
|
||||
if token_lora_mapping is None:
|
||||
token_lora_mapping = token_lora_mapping_meta
|
||||
fused_moe_lora(
|
||||
@@ -408,6 +435,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
num_active_loras,
|
||||
adapter_enabled,
|
||||
shrink_config.get("BLOCK_SIZE_M", 64),
|
||||
shrink_config.get("BLOCK_SIZE_N", 64),
|
||||
|
||||
@@ -44,6 +44,25 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
|
||||
"""
|
||||
Returns num_active_loras values for cudagraph capture.
|
||||
|
||||
When specialize=True: powers of 2 up to max_loras, plus max_loras + 1.
|
||||
When specialize=False: just [max_loras + 1].
|
||||
|
||||
This is the single source of truth for LoRA capture cases, used by both
|
||||
CudagraphDispatcher and PunicaWrapperGPU.
|
||||
"""
|
||||
if not specialize:
|
||||
return [max_loras + 1]
|
||||
|
||||
return [
|
||||
n for n in range(1, max_loras + 2) if (n & (n - 1)) == 0 or n == max_loras + 1
|
||||
]
|
||||
|
||||
|
||||
_GLOBAL_LORA_ID = 0
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from itertools import product
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.utils import get_captured_lora_counts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -57,6 +58,11 @@ class CudagraphDispatcher:
|
||||
)
|
||||
|
||||
self.keys_initialized = False
|
||||
self.specialize_lora_count = (
|
||||
self.vllm_config.lora_config.specialize_active_lora
|
||||
if self.vllm_config.lora_config is not None
|
||||
else False
|
||||
)
|
||||
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
@@ -92,8 +98,33 @@ class CudagraphDispatcher:
|
||||
"Use values from cudagraph_capture_sizes."
|
||||
)
|
||||
|
||||
def _get_lora_cases(self) -> list[int]:
|
||||
"""
|
||||
Returns the list of num_active_loras values for CUDA graph capture (0 = no LoRA).
|
||||
This is the single source of truth for LoRA capture cases.
|
||||
"""
|
||||
lora_config = self.vllm_config.lora_config
|
||||
if lora_config is None:
|
||||
# No LoRA configured - single case with no LoRA
|
||||
return [0]
|
||||
|
||||
# LoRA is enabled - capture graphs based on cudagraph_specialize_lora
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
captured_counts = get_captured_lora_counts(
|
||||
lora_config.max_loras, self.specialize_lora_count
|
||||
)
|
||||
# Specialize: capture a no-LoRA graph (0) plus one per captured LoRA count
|
||||
return [0] + captured_counts
|
||||
else:
|
||||
# No specialization: only capture graphs with LoRA active
|
||||
return [lora_config.max_loras + 1]
|
||||
|
||||
def _create_padded_batch_descriptor(
|
||||
self, num_tokens: int, uniform_decode: bool, has_lora: bool
|
||||
self,
|
||||
num_tokens: int,
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
num_active_loras: int = 0,
|
||||
) -> BatchDescriptor:
|
||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
uniform_decode_query_len = self.uniform_decode_query_len
|
||||
@@ -111,6 +142,7 @@ class CudagraphDispatcher:
|
||||
num_reqs=num_reqs,
|
||||
uniform=uniform_decode,
|
||||
has_lora=has_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
|
||||
def add_cudagraph_key(
|
||||
@@ -135,26 +167,23 @@ class CudagraphDispatcher:
|
||||
|
||||
self._compute_bs_to_padded_graph_size()
|
||||
|
||||
# LoRA activation cases to specialize the cuda graphs on
|
||||
if self.vllm_config.lora_config:
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases = [True, False]
|
||||
else:
|
||||
lora_cases = [True]
|
||||
else:
|
||||
lora_cases = [False]
|
||||
# Get LoRA cases to capture
|
||||
lora_cases = self._get_lora_cases()
|
||||
self.captured_lora_counts = [
|
||||
lora_count for lora_count in lora_cases if lora_count
|
||||
]
|
||||
|
||||
# Note: we create all valid keys for cudagraph here but do not
|
||||
# guarantee all keys would be used. For example, if we allow lazy
|
||||
# capturing in future PR, some keys may never be triggered.
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
for bs, has_lora in product(
|
||||
for bs, num_active_loras in product(
|
||||
self.compilation_config.cudagraph_capture_sizes, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
cudagraph_mode.mixed_mode(),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, False, has_lora
|
||||
bs, False, num_active_loras > 0, num_active_loras
|
||||
).relax_for_mixed_batch_cudagraphs(),
|
||||
)
|
||||
|
||||
@@ -173,10 +202,14 @@ class CudagraphDispatcher:
|
||||
for x in self.compilation_config.cudagraph_capture_sizes
|
||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||
]
|
||||
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
|
||||
for bs, num_active_loras in product(
|
||||
cudagraph_capture_sizes_for_decode, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
CUDAGraphMode.FULL,
|
||||
self._create_padded_batch_descriptor(bs, True, has_lora),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, True, num_active_loras > 0, num_active_loras
|
||||
),
|
||||
)
|
||||
|
||||
self.keys_initialized = True
|
||||
@@ -187,6 +220,7 @@ class CudagraphDispatcher:
|
||||
uniform_decode: bool = False,
|
||||
has_lora: bool = False,
|
||||
disable_full: bool = False,
|
||||
num_active_loras: int = 0,
|
||||
) -> tuple[CUDAGraphMode, BatchDescriptor]:
|
||||
"""
|
||||
Given conditions(e.g.,batch descriptor and if using piecewise only),
|
||||
@@ -202,6 +236,7 @@ class CudagraphDispatcher:
|
||||
disable_full: If True, skip FULL cudagraph checks and
|
||||
return PIECEWISE or NONE only. (can be used for features like
|
||||
cascade attention that are not supported by full cudagraphs)
|
||||
num_active_loras: Number of distinct active LoRA adapters.
|
||||
"""
|
||||
if (
|
||||
not self.keys_initialized
|
||||
@@ -210,8 +245,24 @@ class CudagraphDispatcher:
|
||||
):
|
||||
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
||||
|
||||
effective_num_active_loras = num_active_loras
|
||||
if has_lora and num_active_loras > 0:
|
||||
if self.specialize_lora_count:
|
||||
# Find the smallest captured `num_active_loras` that is >= the current
|
||||
# `num_active_loras`. This is because we only capture graphs for
|
||||
# a subset of possible `num_active_loras` values (powers of 2).
|
||||
import bisect
|
||||
|
||||
idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
|
||||
if idx < len(self.captured_lora_counts):
|
||||
effective_num_active_loras = self.captured_lora_counts[idx]
|
||||
else:
|
||||
# When not specializing, graphs are captured only with max_loras + 1,
|
||||
# so we must use max_loras + 1 for dispatch to find a matching graph.
|
||||
effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1
|
||||
|
||||
batch_desc = self._create_padded_batch_descriptor(
|
||||
num_tokens, uniform_decode, has_lora
|
||||
num_tokens, uniform_decode, has_lora, effective_num_active_loras
|
||||
)
|
||||
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
|
||||
|
||||
|
||||
@@ -3082,6 +3082,7 @@ class GPUModelRunner(
|
||||
# be improved in model runner v2)
|
||||
force_uniform_decode: bool | None = None,
|
||||
force_has_lora: bool | None = None,
|
||||
force_num_active_loras: int | None = None,
|
||||
num_encoder_reqs: int = 0,
|
||||
) -> tuple[
|
||||
CUDAGraphMode,
|
||||
@@ -3103,11 +3104,13 @@ class GPUModelRunner(
|
||||
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
|
||||
)
|
||||
|
||||
has_lora = (
|
||||
len(self.input_batch.lora_id_to_lora_request) > 0
|
||||
if force_has_lora is None
|
||||
else force_has_lora
|
||||
# Compute LoRA state for cudagraph dispatch
|
||||
num_active_loras = (
|
||||
force_num_active_loras
|
||||
if force_num_active_loras is not None
|
||||
else len(self.input_batch.lora_id_to_lora_request)
|
||||
)
|
||||
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
|
||||
|
||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||
dispatch_cudagraph = (
|
||||
@@ -3116,6 +3119,7 @@ class GPUModelRunner(
|
||||
has_lora=has_lora,
|
||||
uniform_decode=uniform_decode,
|
||||
disable_full=disable_full,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
if not force_eager
|
||||
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
|
||||
@@ -4606,8 +4610,8 @@ class GPUModelRunner(
|
||||
is_profile: bool = False,
|
||||
create_mixed_batch: bool = False,
|
||||
remove_lora: bool = True,
|
||||
activate_lora: bool = False,
|
||||
is_graph_capturing: bool = False,
|
||||
num_active_loras: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run a dummy forward pass to warm up/profile run or capture the
|
||||
@@ -4630,7 +4634,8 @@ class GPUModelRunner(
|
||||
create_mixed_batch: If True, create a mixed batch with both decode
|
||||
(1 token) and prefill (multiple tokens) requests.
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
activate_lora: If False, dummy_run is performed without LoRAs.
|
||||
num_active_loras: Number of distinct active LoRAs to capture for.
|
||||
LoRA is activated when num_active_loras > 0.
|
||||
"""
|
||||
mm_config = self.vllm_config.model_config.multimodal_config
|
||||
if mm_config and mm_config.mm_encoder_only:
|
||||
@@ -4712,7 +4717,10 @@ class GPUModelRunner(
|
||||
# `force_has_lora` is used for cudagraph capture; because LoRA is
|
||||
# activated later in the context manager, but we need to know the
|
||||
# LoRA state when determining the batch descriptor for capture
|
||||
force_has_lora=activate_lora,
|
||||
force_has_lora=num_active_loras > 0,
|
||||
# `force_num_active_loras` is used for cudagraph capture; because we
|
||||
# need to capture graphs for specific num_active_loras counts
|
||||
force_num_active_loras=num_active_loras,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4782,8 +4790,8 @@ class GPUModelRunner(
|
||||
self.lora_config,
|
||||
num_scheduled_tokens,
|
||||
num_sampled_tokens,
|
||||
activate_lora,
|
||||
remove_lora,
|
||||
num_active_loras,
|
||||
):
|
||||
# Make sure padding doesn't exceed max_num_tokens
|
||||
assert num_tokens_padded <= self.max_num_tokens
|
||||
@@ -4884,7 +4892,10 @@ class GPUModelRunner(
|
||||
# lora cases when cudagraph_specialize_lora is enabled. This is a
|
||||
# short term mitigation for issue mentioned in
|
||||
# https://github.com/vllm-project/vllm/issues/28334
|
||||
if self.compilation_config.cudagraph_specialize_lora and activate_lora:
|
||||
if (
|
||||
self.compilation_config.cudagraph_specialize_lora
|
||||
and num_active_loras > 0
|
||||
):
|
||||
use_cudagraphs = False
|
||||
|
||||
self.drafter.dummy_run(
|
||||
@@ -5259,7 +5270,7 @@ class GPUModelRunner(
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for batch_desc in batch_descriptors:
|
||||
num_tokens = batch_desc.num_tokens
|
||||
activate_lora = batch_desc.has_lora
|
||||
num_active_loras = batch_desc.num_active_loras
|
||||
|
||||
# We currently only capture ubatched graphs when it's a FULL
|
||||
# cudagraph, a uniform decode batch, and the number of tokens
|
||||
@@ -5286,7 +5297,7 @@ class GPUModelRunner(
|
||||
num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
allow_microbatching=allow_microbatching,
|
||||
activate_lora=activate_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
|
||||
# Capture run
|
||||
@@ -5294,7 +5305,7 @@ class GPUModelRunner(
|
||||
num_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
activate_lora=activate_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
is_graph_capturing=True,
|
||||
)
|
||||
self.maybe_remove_all_loras(self.lora_config)
|
||||
|
||||
@@ -133,11 +133,23 @@ class LoRAModelRunnerMixin:
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
num_sampled_tokens: np.ndarray | None = None,
|
||||
activate_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
):
|
||||
"""
|
||||
Context manager to select dummy LoRAs for capture/warmup.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration, or None if LoRA is disabled.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
- 0: No LoRA active (set up zero mappings).
|
||||
- >0: Use exactly this many distinct LoRAs.
|
||||
"""
|
||||
if num_sampled_tokens is None:
|
||||
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
||||
|
||||
# Skip LoRA setup entirely only if no LoRA config
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
@@ -145,15 +157,52 @@ class LoRAModelRunnerMixin:
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
num_loras = lora_config.max_loras
|
||||
max_loras = lora_config.max_loras
|
||||
|
||||
# Determine how many distinct LoRAs to use and whether to include
|
||||
# no-LoRA tokens (-1 entries).
|
||||
# When num_active_loras > max_loras (e.g., max_loras + 1), we need
|
||||
# to include -1 entries to simulate batches with both LoRA and
|
||||
# no-LoRA tokens. This ensures prepare_tensors computes the correct
|
||||
# num_active_loras that matches the cudagraph capture key.
|
||||
if num_active_loras == 0:
|
||||
# No LoRA active - use 0 mappings like the original code
|
||||
effective_num_loras = 0
|
||||
include_no_lora = False
|
||||
elif num_active_loras > max_loras:
|
||||
# num_active_loras > max_loras means we want max_loras adapters
|
||||
# PLUS no-LoRA tokens (-1). This is the max_loras + 1 case.
|
||||
effective_num_loras = max_loras
|
||||
include_no_lora = True
|
||||
else:
|
||||
# Specific number of active LoRAs requested
|
||||
effective_num_loras = min(num_active_loras, max_loras)
|
||||
include_no_lora = False
|
||||
|
||||
# Make prompt lora mapping
|
||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||
if activate_lora:
|
||||
prompt_lora_mapping = (
|
||||
np.arange(num_reqs, dtype=np.int32) % num_loras
|
||||
) + 1
|
||||
# LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest.
|
||||
# convert_mapping() will convert these to 0-indexed slot indices.
|
||||
if effective_num_loras > 0:
|
||||
if include_no_lora:
|
||||
# Include -1 (no-LoRA) entries by cycling through
|
||||
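# -1, 1, 2, ..., effective_num_loras (NOTE(review): cycle_values below is built from range(1, effective_num_loras + 1) and contains no -1 -- confirm intent)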
|
||||
# This ensures prepare_tensors sees both LoRA and no-LoRA
|
||||
# tokens, computing num_active_loras = effective_num_loras+1
|
||||
cycle_values = np.array(
|
||||
list(range(1, effective_num_loras + 1)),
|
||||
dtype=np.int32,
|
||||
)
|
||||
prompt_lora_mapping = cycle_values[
|
||||
np.arange(num_reqs, dtype=np.int32) % len(cycle_values)
|
||||
]
|
||||
else:
|
||||
# Use 1 to effective_num_loras (1-indexed lora IDs)
|
||||
prompt_lora_mapping = (
|
||||
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
|
||||
) + 1
|
||||
else:
|
||||
# No LoRA active - use 0 for all tokens (original behavior)
|
||||
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
|
||||
|
||||
# Make sample lora mapping
|
||||
@@ -162,14 +211,14 @@ class LoRAModelRunnerMixin:
|
||||
# Make token lora mapping
|
||||
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
|
||||
|
||||
# Make dummy lora requests
|
||||
# Make dummy lora requests (only for the active LoRAs)
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
for lora_id in range(1, effective_num_loras + 1)
|
||||
}
|
||||
|
||||
self._set_active_loras(
|
||||
@@ -187,10 +236,21 @@ class LoRAModelRunnerMixin:
|
||||
lora_config: LoRAConfig | None,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_sampled_tokens: np.ndarray,
|
||||
activate_lora: bool = True,
|
||||
remove_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
):
|
||||
"""
|
||||
Context manager for dummy runs with LoRA.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
remove_lora: Whether to remove LoRAs after the context exits.
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
LoRA is activated when num_active_loras > 0.
|
||||
"""
|
||||
with (
|
||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||
self.maybe_select_dummy_loras(
|
||||
@@ -198,7 +258,7 @@ class LoRAModelRunnerMixin:
|
||||
num_scheduled_tokens,
|
||||
mapping_type,
|
||||
num_sampled_tokens,
|
||||
activate_lora,
|
||||
num_active_loras,
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user