Reduce kernel overhead when the number of active LoRAs is smaller than max_loras. Multiple CUDA graphs are captured, one for each number of active LoRAs. (#32005)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from vllm.config import (
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
@@ -47,6 +48,12 @@ def _create_vllm_config(
|
||||
mock_config.speculative_config = None # No speculative decoding
|
||||
if not lora_config:
|
||||
mock_config.lora_config = None
|
||||
else:
|
||||
# Create a real LoRAConfig with specialize_active_lora enabled
|
||||
mock_config.lora_config = LoRAConfig(
|
||||
max_loras=4,
|
||||
specialize_active_lora=True,
|
||||
)
|
||||
# Mimic the behavior of VllmConfig.__post_init__()
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
compilation_config.set_splitting_ops_for_v1(
|
||||
@@ -106,15 +113,19 @@ class TestCudagraphDispatcher:
|
||||
)
|
||||
|
||||
# Verify the key is initialized correctly
|
||||
# With LoRA specialization (max_loras=4, specialize_active_lora=True):
|
||||
# - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
|
||||
# - capture_sizes = [1, 8]
|
||||
# - Total keys = 2 sizes × 5 lora_cases = 10
|
||||
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
|
||||
4 if lora_config else 2
|
||||
10 if lora_config else 2
|
||||
)
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
|
||||
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
|
||||
4 if lora_config else 2
|
||||
10 if lora_config else 2
|
||||
)
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
|
||||
|
||||
Reference in New Issue
Block a user