[Misc] Fused MoE Marlin support for GPTQ (#8217)

Author: Dipika Sikka
Date:   2024-09-09 23:02:52 -04:00
Committed by: GitHub
Parent: c7cb5c3335
Commit: 6cd5e5b07e

19 changed files with 912 additions and 204 deletions

@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
                     continue
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param,
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
                     continue
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
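
All three hunks add the same filter: in addition to names ending in ".bias", checkpoint entries whose names end in "_bias" are now skipped whenever the model defines no parameter of that name. A minimal standalone sketch of that filter follows; the weight names and the contents of `params_dict` are hypothetical examples, not taken from a real Mixtral checkpoint:

```python
# Sketch of the bias-skip filter added in this commit. The parameter
# and weight names below are made up for illustration only.
params_dict = {
    "model.layers.0.self_attn.qkv_proj.weight": object(),
    "model.layers.0.block_sparse_moe.experts.w13_qweight": object(),
}

checkpoint_names = [
    "model.layers.0.self_attn.qkv_proj.weight",    # kept: parameter exists
    "model.layers.0.self_attn.qkv_proj.bias",      # skipped: extra ".bias"
    "model.layers.0.block_sparse_moe.experts.w13_bias",  # skipped: extra "_bias"
]

for name in checkpoint_names:
    # Skip loading extra bias for GPTQ models: both ".bias" and "_bias"
    # suffixes are ignored when no matching parameter is registered.
    if ((name.endswith(".bias") or name.endswith("_bias"))
            and name not in params_dict):
        continue
    print("loading", name)
```

Running the sketch loads only the first entry; the two bias-like names fall through the `continue`, which mirrors how the loader now tolerates GPTQ checkpoints that carry bias tensors the fused MoE Marlin path does not register as parameters.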