[LoRA][Spec Decode] Support LoRA for Nemotron-H MTP models (#32265)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
danisereb committed via GitHub on 2026-01-27 17:53:26 +02:00
parent 7cbbca9aaa
commit f3a5ee705f
4 changed files with 31 additions and 0 deletions


@@ -521,6 +521,8 @@ class SupportsLoRA(Protocol):
     # are empty by default.
     embedding_modules: ClassVar[dict[str, str]] = {}
     packed_modules_mapping: dict[str, list[str]] = {}
+    # Module prefixes to skip during LoRA loading (e.g., ["mtp."] for MTP layers)
+    lora_skip_prefixes: ClassVar[list[str]] = []
     # We can't use runtime_checkable with ClassVar for issubclass checks
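
As an illustration (not code from this commit), the sketch below shows how a loader could consult lora_skip_prefixes when deciding whether a module should receive LoRA weights. The helper should_skip_lora_module and the example class are hypothetical; the attribute mirrors the one added to SupportsLoRA above.

def should_skip_lora_module(module_name: str, skip_prefixes: list[str]) -> bool:
    """Return True if the module falls under a prefix excluded from LoRA loading."""
    return any(module_name.startswith(prefix) for prefix in skip_prefixes)

class ExampleMTPModel:
    # Mirrors the new class attribute: MTP draft layers are excluded from LoRA.
    lora_skip_prefixes: list[str] = ["mtp."]

for name in ("model.layers.0.self_attn.q_proj", "mtp.layers.0.mixer.in_proj"):
    skipped = should_skip_lora_module(name, ExampleMTPModel.lora_skip_prefixes)
    print(f"{name}: {'skipped' if skipped else 'LoRA applied'}")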


@@ -771,6 +771,9 @@ class NemotronHForCausalLM(
         "lm_head": "output_embeddings",
     }
+    # Skip MTP (Multi-Token Prediction) layers during LoRA loading
+    lora_skip_prefixes = ["mtp."]
+
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,
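
For context, a hedged usage sketch with vLLM's offline Python API: the model id and adapter path below are placeholders, and the MTP speculative-decoding configuration is omitted. With this change, the model's MTP layers (prefix "mtp.") are excluded when the LoRA adapter is applied during loading.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model id and adapter path; spec-decode/MTP settings are omitted here.
llm = LLM(model="nvidia/Nemotron-H-8B-Base-8K", enable_lora=True)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)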