Reduce kernel overhead when the number of active LoRAs is smaller than the maximum: multiple CUDA graphs are captured, one for each active-LoRA case. (#32005)
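
A minimal sketch (illustration only, not the vLLM implementation) of how the specialized active-LoRA cases could be enumerated. The helper name lora_capture_cases is hypothetical, and the bucketing (no-LoRA, powers of two, max_loras, and a max_loras + 1 case) is inferred from the test comment in the diff below:

    # Hypothetical helper, inferred from the test expectations below;
    # not the actual vLLM source.
    def lora_capture_cases(max_loras: int) -> list[int]:
        cases = [0]  # no-LoRA case gets its own graph
        n = 1
        while n < max_loras:  # powers of two below max_loras
            cases.append(n)
            n *= 2
        cases.append(max_loras)      # exactly max_loras active
        cases.append(max_loras + 1)  # extra case beyond max_loras, per the test comment
        return cases

    # Matches the [0, 1, 2, 4, 5] cases asserted in the updated test.
    assert lora_capture_cases(4) == [0, 1, 2, 4, 5]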

Signed-off-by: Yu Gong <yu3.gong@gmail.com>
yugong333
2026-02-02 09:30:06 -08:00
committed by GitHub
parent 8b7346d5f1
commit ffe1fc7a28
15 changed files with 323 additions and 66 deletions

@@ -17,6 +17,7 @@ from vllm.config import (
     SchedulerConfig,
     VllmConfig,
 )
+from vllm.config.lora import LoRAConfig
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.platforms import current_platform
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
@@ -47,6 +48,12 @@ def _create_vllm_config(
     mock_config.speculative_config = None  # No speculative decoding
     if not lora_config:
         mock_config.lora_config = None
+    else:
+        # Create a real LoRAConfig with specialize_active_lora enabled
+        mock_config.lora_config = LoRAConfig(
+            max_loras=4,
+            specialize_active_lora=True,
+        )
     # Mimic the behavior of VllmConfig.__post_init__()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         compilation_config.set_splitting_ops_for_v1(
@@ -106,15 +113,19 @@ class TestCudagraphDispatcher:
         )
         # Verify the key is initialized correctly
+        # With LoRA specialization (max_loras=4, specialize_active_lora=True):
+        # - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
+        # - capture_sizes = [1, 8]
+        # - Total keys = 2 sizes × 5 lora_cases = 10
         if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
-                4 if lora_config else 2
+                10 if lora_config else 2
             )
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
         if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
-                4 if lora_config else 2
+                10 if lora_config else 2
             )
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
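
As a hedged illustration of where the expected count of 10 comes from (assuming, per the test comment above, capture_sizes = [1, 8] and lora_cases = [0, 1, 2, 4, 5]), each (capture size, lora case) pair yields one dispatch key; the real dispatcher builds richer key objects than bare tuples:

    from itertools import product

    # Illustrative only: one key per (capture_size, lora_case) pair.
    capture_sizes = [1, 8]
    lora_cases = [0, 1, 2, 4, 5]
    keys = list(product(capture_sizes, lora_cases))
    assert len(keys) == 10  # 2 sizes x 5 lora cases, as asserted above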