[LoRA][Spec Decode] Support LoRA for Nemotron-H MTP models (#32265)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
danisereb
2026-01-27 17:53:26 +02:00
committed by GitHub
parent 7cbbca9aaa
commit f3a5ee705f
4 changed files with 31 additions and 0 deletions

View File

@@ -62,6 +62,14 @@ class LoRAModel:
def check_lora_name(self, lora_name: str) -> bool:
    """Report whether ``lora_name`` is present in this model's ``loras``.

    Args:
        lora_name: Name to look up.

    Returns:
        The result of the membership test ``lora_name in self.loras``.
    """
    registered = self.loras
    return lora_name in registered
@staticmethod
def _should_skip_module(module_name: str, skip_prefixes: list[str]) -> bool:
"""Check if a module should be skipped based on skip prefixes"""
for prefix in skip_prefixes:
if f".{prefix}" in module_name or module_name.startswith(prefix):
return True
return False
@classmethod
def from_lora_tensors(
cls,
@@ -72,6 +80,7 @@ class LoRAModel:
dtype: torch.dtype | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
skip_prefixes: list[str] | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
@@ -79,6 +88,9 @@ class LoRAModel:
for tensor_name, tensor in tensors.items():
if is_base_embeddding_weights(tensor_name):
continue
# Skip modules based on model-defined prefixes (e.g., MTP layers)
if skip_prefixes and cls._should_skip_module(tensor_name, skip_prefixes):
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper
)
@@ -121,6 +133,7 @@ class LoRAModel:
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
tensorizer_config_dict: dict | None = None,
skip_prefixes: list[str] | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
@@ -133,6 +146,9 @@ class LoRAModel:
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
skip_prefixes: List of module name prefixes to skip during loading.
Models can define this to skip modules not used in inference
(e.g., MTP layers). Format: ["mtp."]
Returns:
Loaded LoRA Model.
@@ -152,6 +168,11 @@ class LoRAModel:
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
# Skip modules based on model-defined prefixes
if skip_prefixes and cls._should_skip_module(
lora_module, skip_prefixes
):
continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights
if ".experts" in module_name:
@@ -218,4 +239,5 @@ class LoRAModel:
dtype=dtype,
model_vocab_size=model_vocab_size,
weights_mapper=weights_mapper,
skip_prefixes=skip_prefixes,
)

View File

@@ -114,6 +114,9 @@ class WorkerLoRAManager:
model = self._adapter_manager.model
hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
# Get model-defined prefixes to skip during LoRA loading.
lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)
lora = self._lora_model_cls.from_local_checkpoint(
lora_path,
expected_lora_modules,
@@ -124,6 +127,7 @@ class WorkerLoRAManager:
model_vocab_size=self.vocab_size,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
weights_mapper=hf_to_vllm_mapper,
skip_prefixes=lora_skip_prefixes,
)
except FileNotFoundError as e:

View File

@@ -521,6 +521,8 @@ class SupportsLoRA(Protocol):
# are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {}
packed_modules_mapping: dict[str, list[str]] = {}
# Module prefixes to skip during LoRA loading (e.g., ["mtp."] for MTP layers)
lora_skip_prefixes: ClassVar[list[str]] = []
# We can't use runtime_checkable with ClassVar for issubclass checks

View File

@@ -771,6 +771,9 @@ class NemotronHForCausalLM(
"lm_head": "output_embeddings",
}
# Skip MTP (Multi-Token Prediction) layers during LoRA loading
lora_skip_prefixes = ["mtp."]
@classmethod
def get_mamba_state_dtype_from_config(
cls,