diff --git a/vllm/lora/lora_model.py b/vllm/lora/lora_model.py
index bc88c71ea..e9e0a711a 100644
--- a/vllm/lora/lora_model.py
+++ b/vllm/lora/lora_model.py
@@ -62,6 +62,14 @@ class LoRAModel:
     def check_lora_name(self, lora_name: str) -> bool:
         return lora_name in self.loras
 
+    @staticmethod
+    def _should_skip_module(module_name: str, skip_prefixes: list[str]) -> bool:
+        """Check if a module should be skipped based on skip prefixes"""
+        for prefix in skip_prefixes:
+            if f".{prefix}" in module_name or module_name.startswith(prefix):
+                return True
+        return False
+
     @classmethod
     def from_lora_tensors(
         cls,
@@ -72,6 +80,7 @@ class LoRAModel:
         dtype: torch.dtype | None = None,
         model_vocab_size: int | None = None,
         weights_mapper: WeightsMapper | None = None,
+        skip_prefixes: list[str] | None = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
@@ -79,6 +88,9 @@ class LoRAModel:
         for tensor_name, tensor in tensors.items():
             if is_base_embeddding_weights(tensor_name):
                 continue
+            # Skip modules based on model-defined prefixes (e.g., MTP layers)
+            if skip_prefixes and cls._should_skip_module(tensor_name, skip_prefixes):
+                continue
             module_name, is_lora_a = parse_fine_tuned_lora_name(
                 tensor_name, weights_mapper
             )
@@ -121,6 +133,7 @@ class LoRAModel:
         model_vocab_size: int | None = None,
         weights_mapper: WeightsMapper | None = None,
         tensorizer_config_dict: dict | None = None,
+        skip_prefixes: list[str] | None = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a local checkpoint.
 
@@ -133,6 +146,9 @@ class LoRAModel:
                 a global counter.
             device: Device where the lora model is loaded.
             dtype: dtype of the lora model weights.
+            skip_prefixes: List of module name prefixes to skip during loading.
+                Models can define this to skip modules not used in inference
+                (e.g., MTP layers). Format: ["mtp."]
 
         Returns:
             Loaded LoRA Model.
@@ -152,6 +168,11 @@ class LoRAModel:
                 # gate_up_proj and experts is the down_proj
                 if "base_layer" in lora_module:
                     continue
+                # Skip modules based on model-defined prefixes
+                if skip_prefixes and cls._should_skip_module(
+                    lora_module, skip_prefixes
+                ):
+                    continue
                 module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                 # Case for expert lora weights
                 if ".experts" in module_name:
@@ -218,4 +239,5 @@ class LoRAModel:
             dtype=dtype,
             model_vocab_size=model_vocab_size,
             weights_mapper=weights_mapper,
+            skip_prefixes=skip_prefixes,
         )
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 598c10407..2db747e2c 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -114,6 +114,9 @@ class WorkerLoRAManager:
             model = self._adapter_manager.model
             hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
 
+            # Get model-defined prefixes to skip during LoRA loading.
+            lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)
+
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
@@ -124,6 +127,7 @@ class WorkerLoRAManager:
                 model_vocab_size=self.vocab_size,
                 tensorizer_config_dict=lora_request.tensorizer_config_dict,
                 weights_mapper=hf_to_vllm_mapper,
+                skip_prefixes=lora_skip_prefixes,
             )
 
         except FileNotFoundError as e:
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 65588ac33..ea763afd5 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -521,6 +521,8 @@ class SupportsLoRA(Protocol):
     # are empty by default.
     embedding_modules: ClassVar[dict[str, str]] = {}
     packed_modules_mapping: dict[str, list[str]] = {}
+    # Module prefixes to skip during LoRA loading (e.g., ["mtp."] for MTP layers)
+    lora_skip_prefixes: ClassVar[list[str]] = []
 
 
 # We can't use runtime_checkable with ClassVar for issubclass checks
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index e33bbe9fa..a11224657 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -771,6 +771,9 @@ class NemotronHForCausalLM(
         "lm_head": "output_embeddings",
     }
 
+    # Skip MTP (Multi-Token Prediction) layers during LoRA loading
+    lora_skip_prefixes = ["mtp."]
+
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,
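
A minimal standalone sketch (not part of the patch) of the prefix rule that
_should_skip_module applies: a tensor is skipped when a prefix either starts
its name or appears after a "." separator. The module names below are
hypothetical and only illustrate which tensors a model's lora_skip_prefixes
would filter out.

# Sketch of the matching rule; mirrors LoRAModel._should_skip_module above.
def should_skip_module(module_name: str, skip_prefixes: list[str]) -> bool:
    for prefix in skip_prefixes:
        # Match a leading prefix or a dot-delimited occurrence anywhere in the name.
        if f".{prefix}" in module_name or module_name.startswith(prefix):
            return True
    return False

skip_prefixes = ["mtp."]
print(should_skip_module("mtp.layers.0.mixer.in_proj", skip_prefixes))        # True: leading "mtp."
print(should_skip_module("model.mtp.layers.0.mixer.in_proj", skip_prefixes))  # True: ".mtp." inside the name
print(should_skip_module("model.layers.0.mixer.in_proj", skip_prefixes))      # False: kept and loaded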