diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index 3df3a606c..f3c3cb8cf 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -187,7 +187,8 @@ def use_fused_moe_lora_kernel(
 
     # num_active_loras is the number of active LoRAs
     # (max_loras + 1 to include no-lora case)
-    num_active_loras = max_loras + 1
+    # Stored as CPU tensor to match the kernel API (torch.compile compatibility)
+    num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")
 
     fused_moe_lora(
         output,
@@ -399,7 +400,8 @@ def use_fused_moe_lora_kernel_naive(
 
     # num_active_loras is the number of active LoRAs
     # (max_loras + 1 to include no-lora case)
-    num_active_loras = max_loras + 1
+    # Stored as CPU tensor to match the kernel API (torch.compile compatibility)
+    num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")
 
     fused_moe_lora(
         output,
diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py
index 14d0ff47d..855b6b796 100644
--- a/tests/lora/test_gptoss_tp.py
+++ b/tests/lora/test_gptoss_tp.py
@@ -70,8 +70,12 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
 
 
 @pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
+@pytest.mark.parametrize("specialize_active_lora", [True, False])
 def test_gpt_oss_lora(
-    monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
+    monkeypatch: pytest.MonkeyPatch,
+    gptoss20b_lora_files,
+    mxfp4_use_marlin,
+    specialize_active_lora,
 ):
     with monkeypatch.context() as m:
         m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
@@ -83,6 +87,7 @@ def test_gpt_oss_lora(
             max_lora_rank=8,
             max_num_seqs=2,
             max_num_batched_tokens=2048,
+            specialize_active_lora=specialize_active_lora,
             compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
                 cudagraph_specialize_lora=False,
             ),
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 8072f8769..7fc49d8d8 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -127,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
 
 
 def _adjust_kernel_inputs(
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     sorted_token_ids: torch.Tensor | None,
     expert_ids: torch.Tensor,
 ):
@@ -141,7 +141,7 @@ def _adjust_kernel_inputs(
     else:
         stride_tl = sorted_token_ids.stride(0)
         stride_el = expert_ids.stride(0)
-    grid_lora_dim = num_active_loras
+    grid_lora_dim = num_active_loras.item()
 
     return grid_lora_dim, stride_tl, stride_el
 
@@ -444,7 +444,7 @@ def _fused_moe_lora_shrink(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     use_gdc: bool = False,
     use_tma: bool = False,
@@ -562,7 +562,7 @@ def _fused_moe_lora_expand(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     offset: int = 0,
     use_gdc: bool = False,
@@ -683,7 +683,7 @@ def _fused_moe_lora(
     max_lora_rank: int,
     top_k_num: int,
     lora_ids: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     adapter_enabled: torch.Tensor,
     shrink_block_size_m: int,
     shrink_block_size_n: int,
@@ -871,7 +871,7 @@ def _fused_moe_lora_fake(
     max_lora_rank: int,
     top_k_num: int,
     lora_ids: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     adapter_enabled: torch.Tensor,
     shrink_block_size_m: int,
     shrink_block_size_n: int,
@@ -921,7 +921,7 @@ def _fused_moe_lora_shrink_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     use_gdc: bool = False,
     use_tma: bool = False,
@@ -958,7 +958,7 @@ def _fused_moe_lora_expand_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     offset: int = 0,
     use_gdc: bool = False,
diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index 1557d37d2..343e0c810 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -138,7 +138,7 @@ def _lora_expand(
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
     no_lora_flag_cpu: torch.Tensor,  # shape [1]
-    num_active_loras: int,  # number of active LoRAs (unused here, for API compat)
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
@@ -235,7 +235,7 @@ def _lora_expand(
     grid = (
         triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
         NUM_SLICES,
-        num_active_loras,
+        num_active_loras.item(),
     )
     # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
     # making PDL invalid and affecting the kernel performance.
@@ -289,7 +289,7 @@ def _lora_expand_fake(
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
     no_lora_flag_cpu: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
index 1fec1d50c..dd7c2c706 100644
--- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
@@ -29,9 +29,16 @@ class LoRAKernelMeta:
     # to early exit from inside the lora_expand / lora_shrink torch operation.
     no_lora_flag_cpu: torch.Tensor
 
-    # Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
-    # Stored as a Python int to avoid GPU->CPU sync during forward pass
-    num_active_loras: int = 0
+    # Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
+    # Stored as a CPU tensor (not a Python int) so that torch.compile treats
+    # it as a dynamic value rather than baking it as a constant at trace time.
+    # This follows the same pattern as no_lora_flag_cpu above.
+    num_active_loras_cpu: torch.Tensor
+
+    # Default num_active_loras value (max_loras + 1) as a CPU tensor,
+    # used when specialize_active_lora is False to avoid allocating a
+    # new tensor on every meta_args() call.
+    default_num_active_loras_cpu: torch.Tensor
 
     # Captured LoRA counts for cudagraph specialization (sorted list).
     # When specialize_active_lora is enabled, num_active_loras is rounded up
@@ -73,6 +80,11 @@ class LoRAKernelMeta:
 
         no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
 
+        num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
+        default_num_active_loras_cpu = torch.tensor(
+            [max_loras + 1], dtype=torch.int32, device="cpu"
+        )
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
@@ -80,6 +92,8 @@ class LoRAKernelMeta:
             num_tokens_per_lora=num_tokens_per_lora,
             lora_token_start_loc=lora_token_start_loc,
             no_lora_flag_cpu=no_lora_flag_cpu,
+            num_active_loras_cpu=num_active_loras_cpu,
+            default_num_active_loras_cpu=default_num_active_loras_cpu,
             captured_lora_counts=sorted(captured_lora_counts)
             if captured_lora_counts
             else [],
@@ -90,8 +104,7 @@ class LoRAKernelMeta:
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
         self.no_lora_flag_cpu.fill_(False)
-        self.num_active_loras = 0
-        self.captured_lora_counts = []
+        self.num_active_loras_cpu.fill_(0)
 
     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -137,14 +150,16 @@ class LoRAKernelMeta:
             num_tokens_per_lora, non_blocking=True
         )
 
-        self.num_active_loras = lora_ids.size(0)
+        num_active_loras = lora_ids.size(0)
 
         # Round up num_active_loras to match cudagraph capture keys.
         # This ensures the kernel grid dimension matches the captured graph.
-        if self.captured_lora_counts and self.num_active_loras > 0:
-            idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
+        if self.captured_lora_counts and num_active_loras > 0:
+            idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
             if idx < len(self.captured_lora_counts):
-                self.num_active_loras = self.captured_lora_counts[idx]
+                num_active_loras = self.captured_lora_counts[idx]
+
+        self.num_active_loras_cpu[0] = num_active_loras
 
         # lora_token_start_loc
         lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
@@ -163,7 +178,7 @@ class LoRAKernelMeta:
         torch.Tensor,
         torch.Tensor,
         torch.Tensor,
-        int,
+        torch.Tensor,
     ]:
         """
         This function returns the kernel metadata required for the current
@@ -175,7 +190,10 @@ class LoRAKernelMeta:
            token_nums (int): Number of input tokens in the current forward
            pass of the kernel.
         """
-        max_loras = self.active_lora_ids.size(0) - 1
+        if specialize_active_lora:
+            num_active_loras = self.num_active_loras_cpu
+        else:
+            num_active_loras = self.default_num_active_loras_cpu
         return (
             self.token_lora_mapping[:token_nums],
             self.token_indices_sorted_by_lora_ids[:token_nums],
@@ -183,5 +201,5 @@ class LoRAKernelMeta:
             self.lora_token_start_loc,
             self.active_lora_ids,
             self.no_lora_flag_cpu,
-            self.num_active_loras if specialize_active_lora else max_loras + 1,
+            num_active_loras,
         )
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 8dbd988f7..ea850baa2 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -134,7 +134,7 @@ def _lora_shrink(
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
     no_lora_flag_cpu: torch.Tensor,  # shape [1]
-    num_active_loras: int,  # number of active LoRAs (unused here, for API compat)
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     scaling: float,
 ) -> None:
     """
@@ -157,6 +157,9 @@ def _lora_shrink(
         lora_ids (torch.Tensor): LoRA ids to process.
         no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
             if there are any requests that require LoRA.
+        num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
+            number of active LoRAs. Stored as a tensor (not int) so
+            torch.compile treats it as dynamic rather than a constant.
         scaling (float): Scaling factor.
     """
 
@@ -215,7 +218,7 @@ def _lora_shrink(
     grid = (
         SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
         NUM_SLICES,
-        num_active_loras,
+        num_active_loras.item(),
     )
     # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
     # making PDL invalid and affecting the kernel performance.
@@ -267,7 +270,7 @@ def _lora_shrink_fake(
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
     no_lora_flag_cpu: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     scaling: float,
 ) -> None:
     return
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 59a82d4ce..36abee66e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5379,6 +5379,7 @@ class GPUModelRunner(
                # if we want to warm up attention or not. This is
                # different from the case where `FULL` implies capture
                # attention while `PIECEWISE` implies no attention.
+
                dummy_run(
                    num_tokens,
                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
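For context on the pattern this diff applies throughout: a Python int argument gets baked into the graph as a constant when torch.compile traces the op, so every change in the active-LoRA count would invalidate the trace, whereas a [1]-shaped CPU tensor is treated as a graph input whose value can change, and reading it with .item() inside the eager custom op incurs no GPU->CPU sync. The following is a minimal standalone sketch of that pattern, not part of the diff; ToyMeta and launch_grid are hypothetical names used only for illustration.

import torch

class ToyMeta:
    """Hypothetical stand-in for LoRAKernelMeta's CPU-tensor bookkeeping."""

    def __init__(self, max_loras: int):
        # Mirrors num_active_loras_cpu: a [1] int32 CPU tensor, mutated in
        # place so the tensor object seen by compiled code never changes.
        self.num_active_loras_cpu = torch.tensor(
            [0], dtype=torch.int32, device="cpu"
        )
        # Mirrors default_num_active_loras_cpu: preallocated once and reused
        # when specialize_active_lora is off, instead of allocating per call.
        self.default_num_active_loras_cpu = torch.tensor(
            [max_loras + 1], dtype=torch.int32, device="cpu"
        )

    def prepare(self, num_active_loras: int) -> None:
        # In-place write: compiled graphs observe a new value, not a new
        # trace-time constant.
        self.num_active_loras_cpu[0] = num_active_loras

def launch_grid(num_active_loras: torch.Tensor) -> int:
    # Inside the eager custom op, .item() extracts the scalar for the kernel
    # grid; since the tensor lives on the CPU, this is not a device sync.
    return num_active_loras.item()

meta = ToyMeta(max_loras=4)
meta.prepare(2)
assert launch_grid(meta.num_active_loras_cpu) == 2
meta.prepare(3)
assert launch_grid(meta.num_active_loras_cpu) == 3  # same tensor, new value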
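The prepare_tensors hunk above also routes the cudagraph round-up logic through the new tensor. Below is a self-contained sketch of that rounding on its own, assuming a hypothetical captured_lora_counts of [2, 4, 8]; in the diff the list comes from cudagraph capture and lives on the LoRAKernelMeta instance.

import bisect

# Hypothetical capture keys; the real values are recorded at capture time.
captured_lora_counts = [2, 4, 8]

def round_up_active_loras(num_active_loras: int) -> int:
    # Round up to the smallest captured count >= num_active_loras so the
    # kernel grid's LoRA dimension matches a captured graph exactly.
    if captured_lora_counts and num_active_loras > 0:
        idx = bisect.bisect_left(captured_lora_counts, num_active_loras)
        if idx < len(captured_lora_counts):
            return captured_lora_counts[idx]
    return num_active_loras

assert round_up_active_loras(3) == 4   # rounded up to the next capture key
assert round_up_active_loras(8) == 8   # exact match is kept
assert round_up_active_loras(9) == 9   # beyond all keys: left unchanged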