[LoRA][Spec Decode] Support LoRA for Nemotron-H MTP models (#32265)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -62,6 +62,14 @@ class LoRAModel:
|
||||
def check_lora_name(self, lora_name: str) -> bool:
|
||||
return lora_name in self.loras
|
||||
|
||||
@staticmethod
|
||||
def _should_skip_module(module_name: str, skip_prefixes: list[str]) -> bool:
|
||||
"""Check if a module should be skipped based on skip prefixes"""
|
||||
for prefix in skip_prefixes:
|
||||
if f".{prefix}" in module_name or module_name.startswith(prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_lora_tensors(
|
||||
cls,
|
||||
@@ -72,6 +80,7 @@ class LoRAModel:
|
||||
dtype: torch.dtype | None = None,
|
||||
model_vocab_size: int | None = None,
|
||||
weights_mapper: WeightsMapper | None = None,
|
||||
skip_prefixes: list[str] | None = None,
|
||||
) -> "LoRAModel":
|
||||
"""Create a LoRAModel from a dictionary of tensors."""
|
||||
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
||||
@@ -79,6 +88,9 @@ class LoRAModel:
|
||||
for tensor_name, tensor in tensors.items():
|
||||
if is_base_embeddding_weights(tensor_name):
|
||||
continue
|
||||
# Skip modules based on model-defined prefixes (e.g., MTP layers)
|
||||
if skip_prefixes and cls._should_skip_module(tensor_name, skip_prefixes):
|
||||
continue
|
||||
module_name, is_lora_a = parse_fine_tuned_lora_name(
|
||||
tensor_name, weights_mapper
|
||||
)
|
||||
@@ -121,6 +133,7 @@ class LoRAModel:
|
||||
model_vocab_size: int | None = None,
|
||||
weights_mapper: WeightsMapper | None = None,
|
||||
tensorizer_config_dict: dict | None = None,
|
||||
skip_prefixes: list[str] | None = None,
|
||||
) -> "LoRAModel":
|
||||
"""Create a LoRAModel from a local checkpoint.
|
||||
|
||||
@@ -133,6 +146,9 @@ class LoRAModel:
|
||||
a global counter.
|
||||
device: Device where the lora model is loaded.
|
||||
dtype: dtype of the lora model weights.
|
||||
skip_prefixes: List of module name prefixes to skip during loading.
|
||||
Models can define this to skip modules not used in inference
|
||||
(e.g., MTP layers). Format: ["mtp."]
|
||||
|
||||
Returns:
|
||||
Loaded LoRA Model.
|
||||
@@ -152,6 +168,11 @@ class LoRAModel:
|
||||
# gate_up_proj and experts is the down_proj
|
||||
if "base_layer" in lora_module:
|
||||
continue
|
||||
# Skip modules based on model-defined prefixes
|
||||
if skip_prefixes and cls._should_skip_module(
|
||||
lora_module, skip_prefixes
|
||||
):
|
||||
continue
|
||||
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
|
||||
# Case for expert lora weights
|
||||
if ".experts" in module_name:
|
||||
@@ -218,4 +239,5 @@ class LoRAModel:
|
||||
dtype=dtype,
|
||||
model_vocab_size=model_vocab_size,
|
||||
weights_mapper=weights_mapper,
|
||||
skip_prefixes=skip_prefixes,
|
||||
)
|
||||
|
||||
@@ -114,6 +114,9 @@ class WorkerLoRAManager:
|
||||
model = self._adapter_manager.model
|
||||
hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
|
||||
|
||||
# Get model-defined prefixes to skip during LoRA loading.
|
||||
lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)
|
||||
|
||||
lora = self._lora_model_cls.from_local_checkpoint(
|
||||
lora_path,
|
||||
expected_lora_modules,
|
||||
@@ -124,6 +127,7 @@ class WorkerLoRAManager:
|
||||
model_vocab_size=self.vocab_size,
|
||||
tensorizer_config_dict=lora_request.tensorizer_config_dict,
|
||||
weights_mapper=hf_to_vllm_mapper,
|
||||
skip_prefixes=lora_skip_prefixes,
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
|
||||
@@ -521,6 +521,8 @@ class SupportsLoRA(Protocol):
|
||||
# are empty by default.
|
||||
embedding_modules: ClassVar[dict[str, str]] = {}
|
||||
packed_modules_mapping: dict[str, list[str]] = {}
|
||||
# Module prefixes to skip during LoRA loading (e.g., ["mtp."] for MTP layers)
|
||||
lora_skip_prefixes: ClassVar[list[str]] = []
|
||||
|
||||
|
||||
# We can't use runtime_checkable with ClassVar for issubclass checks
|
||||
|
||||
@@ -771,6 +771,9 @@ class NemotronHForCausalLM(
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
|
||||
# Skip MTP (Multi-Token Prediction) layers during LoRA loading
|
||||
lora_skip_prefixes = ["mtp."]
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user