[Hardware][NV] Fix ModelOpt model loading of KV scales for Llama models. (#11787)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Author: Pavani Majety
Date: 2025-01-29 01:46:12 -08:00 (committed by GitHub)
Parent: ff7424f491
Commit: b02fd288b2
3 changed files with 20 additions and 6 deletions


@@ -404,6 +404,11 @@ class LlamaModel(nn.Module):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -423,10 +428,6 @@ class LlamaModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
                 if is_pp_missing_parameter(name, self):
                     continue
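
Why the move matters: ModelOpt-quantized Llama checkpoints store KV-cache scales under names like `model.layers.N.self_attn.k_proj.k_scale`. Because such names contain `.k_proj` / `.v_proj`, the stacked-params mapping (which folds the q/k/v projections into `qkv_proj`) matched them first, and the remap in its old position further down was never reached. The snippet below is a minimal, self-contained sketch of that ordering; the `maybe_remap_kv_scale_name` here is a toy stand-in for vLLM's helper of the same name, and the mapping and parameter names are illustrative assumptions, not the actual implementation.

    # Toy reproduction (not the vLLM implementation) of the ordering bug
    # this commit fixes. Names and mapping below are illustrative.

    # Llama fuses q/k/v projections, so checkpoint names containing
    # ".k_proj" normally map to a shard of "qkv_proj".
    stacked_params_mapping = [
        # (param_name, weight_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]

    def maybe_remap_kv_scale_name(name, params_dict):
        """Toy stand-in: '...k_proj.k_scale' -> '...attn.k_scale' when the
        remapped parameter exists, None when it does not."""
        for proj in ("k", "v"):
            old = f".{proj}_proj.{proj}_scale"
            if name.endswith(old):
                new = name[: -len(old)] + f".attn.{proj}_scale"
                return new if new in params_dict else None
        return name

    params_dict = {"model.layers.0.self_attn.attn.k_scale": None}
    name = "model.layers.0.self_attn.k_proj.k_scale"  # ModelOpt-style name

    # Old order: the stacked-params loop ran first and matched ".k_proj",
    # so the scalar scale was treated as a qkv weight shard.
    assert any(w in name for _, w, _ in stacked_params_mapping)

    # New order: remap before the loop, so the loop never sees "k_proj".
    if "scale" in name:
        name = maybe_remap_kv_scale_name(name, params_dict)
    assert name == "model.layers.0.self_attn.attn.k_scale"
    assert not any(w in name for _, w, _ in stacked_params_mapping)

With the remap hoisted above the loop, the scale no longer trips the stacked-params path and is loaded through the plain parameter path under its remapped `attn.k_scale` / `attn.v_scale` name.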