diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py
index fc1be4ed4..30b74ce3e 100644
--- a/tests/lora/test_worker.py
+++ b/tests/lora/test_worker.py
@@ -52,6 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
                 seed=0,
                 dtype="float16",
                 revision=None,
+                enforce_eager=True,
             ),
             load_config=LoadConfig(
                 download_dir=None,
diff --git a/vllm/config.py b/vllm/config.py
index 35411ca73..429ec0dd5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2287,9 +2287,14 @@ class LoRAConfig:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # LoRA is not compatible with `torch.compile` .
         factors: list[Any] = []
+        factors.append(self.max_lora_rank)
+        factors.append(self.max_loras)
+        factors.append(self.fully_sharded_loras)
+        factors.append(self.lora_dtype)
+        factors.append(self.lora_extra_vocab_size)
+        factors.append(self.long_lora_scaling_factors)
+        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
@@ -3303,6 +3308,11 @@ class VllmConfig:
             vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+            # LoRA creates static buffers based on max_num_batched_tokens.
+            # The tensor sizes and strides get captured in the torch.compile
+            # graph explicitly.
+            vllm_factors.append(
+                str(self.scheduler_config.max_num_batched_tokens))
         else:
             vllm_factors.append("None")
         if self.speculative_config:
@@ -3453,12 +3463,15 @@ class VllmConfig:
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
-        if self.lora_config is not None and self.compilation_config.level !=\
-            CompilationLevel.NO_COMPILATION:
-            logger.warning("LoRA is not supported with `torch.compile` yet. "
-                           "Disabling `torch.compile`.")
+        if ((not envs.VLLM_USE_V1) and self.lora_config is not None
+                and self.compilation_config.level
+                != CompilationLevel.NO_COMPILATION):
+            logger.warning(
+                "LoRA for V0 is not supported with `torch.compile` yet. "
+                "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+
         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 1c1f76702..7a9d5237a 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -237,16 +237,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        embeddings_indices = self.punica_wrapper.embeddings_indices
-        indices = embeddings_indices[1].view_as(x)
+        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
+                                        1, 0)
+        embeddings_indices = torch.narrow(
+            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
+
+        indices = embeddings_indices[1]
         full_lora_a_embeddings = F.embedding(
             x + indices,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0].view_as(x)
-        full_output = self.base_layer.forward(
-            x.add_(indices * added_tokens_mask))
+        indices = embeddings_indices[0]
+        full_output = self.base_layer.forward(x +
+                                              (indices * added_tokens_mask))
 
         full_output_org = full_output
         if full_output.ndim == 3:
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index 3a4fcd04d..19a94eea9 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -254,7 +254,9 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
         y_org = y
         y = y.view(-1, y.shape[-1])
         if lora_bias_stacked is not None:
-            self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            self._apply_bias(token_lora_indices, y, output_slices,
                              lora_bias_stacked)
 
         if env.VLLM_USE_V1:
@@ -365,7 +367,9 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
         if lora_bias_stacked is not None:
             assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            y = self._apply_bias(token_lora_indices, y, output_slices,
                                  lora_bias_stacked)
 
         if buffer is None: