[LoRA][Spec Decode] Support LoRA for Nemotron-H MTP models (#32265)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -521,6 +521,8 @@ class SupportsLoRA(Protocol):
     # are empty by default.
     embedding_modules: ClassVar[dict[str, str]] = {}
     packed_modules_mapping: dict[str, list[str]] = {}
+    # Module prefixes to skip during LoRA loading (e.g., ["mtp."] for MTP layers)
+    lora_skip_prefixes: ClassVar[list[str]] = []
 
 
 # We can't use runtime_checkable with ClassVar for issubclass checks
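
For context, here is a minimal sketch of how a LoRA loader could consult the new lora_skip_prefixes attribute when deciding which checkpoint modules are eligible LoRA targets. The helper names and the call site are assumptions for illustration, not the code added in this PR:

# Sketch only: a prefix check a loader might apply before wrapping a module.
def should_skip_lora_module(module_name: str, skip_prefixes: list[str]) -> bool:
    """Return True if this module should not receive a LoRA adapter."""
    return any(module_name.startswith(prefix) for prefix in skip_prefixes)


# With lora_skip_prefixes = ["mtp."], MTP modules are ignored while
# backbone modules remain eligible LoRA targets.
print(should_skip_lora_module("mtp.layers.0.mixer", ["mtp."]))        # True
print(should_skip_lora_module("backbone.layers.0.mixer", ["mtp."]))   # False
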
@@ -771,6 +771,9 @@ class NemotronHForCausalLM(
         "lm_head": "output_embeddings",
     }
 
+    # Skip MTP (Multi-Token Prediction) layers during LoRA loading
+    lora_skip_prefixes = ["mtp."]
+
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,
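
As a usage sketch (assuming the standard vLLM offline LoRA API; the model name and adapter path are placeholders), an adapter can be applied to a Nemotron-H MTP checkpoint while modules under the "mtp." prefix are skipped during LoRA loading:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model name; enable_lora is the standard vLLM switch for adapters.
llm = LLM(model="nvidia/Nemotron-H-MTP-checkpoint", enable_lora=True)

outputs = llm.generate(
    ["Explain multi-token prediction in one sentence."],
    SamplingParams(max_tokens=64),
    # Adapter weights map onto the backbone; "mtp." modules are skipped.
    lora_request=LoRARequest("my-adapter", 1, "/path/to/lora-adapter"),
)
print(outputs[0].outputs[0].text)
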