[Quantization] Enable BNB support for more MoE models (#21100)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Authored by Jee Jee Li on 2025-07-19 08:52:02 +08:00, committed by GitHub
parent 217937221b
commit 466e878f2a
5 changed files with 223 additions and 181 deletions
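This change moves Grok1's expert parameter mapping out of load_weights into a standalone get_expert_mapping() method (exposed on Grok1ForCausalLM as well), so that quantization-aware loaders such as the bitsandbytes (BNB) path can ask the model how per-expert checkpoint weights map onto its fused MoE parameters. As a rough illustration of the data involved, here is a minimal sketch of the expansion that FusedMoE.make_expert_params_mapping plausibly performs; the tuple layout (param_name, weight_name, expert_id, shard_id) follows the method's return annotation in the diff below, while the "w13"/"w2" fused-parameter naming is an assumption, not a copy of the vLLM implementation:

# Sketch only: not vLLM's implementation of make_expert_params_mapping.
def make_expert_params_mapping_sketch(
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, weight_name in (
                ("w1", ckpt_gate_proj_name),   # gate projection
                ("w2", ckpt_down_proj_name),   # down projection
                ("w3", ckpt_up_proj_name)):    # up projection
            # Assumed scheme: gate/up shards fuse into one "w13" parameter,
            # the down projection stays in a separate "w2" parameter.
            param_name = ("experts.w13_" if shard_id in ("w1", "w3")
                          else "experts.w2_")
            mapping.append(
                (param_name, f"experts.{expert_id}.{weight_name}.",
                 expert_id, shard_id))
    return mapping

# With Grok1's checkpoint names, expert 0 would yield entries like:
#   ("experts.w13_", "experts.0.linear.",   0, "w1")
#   ("experts.w2_",  "experts.0.linear_1.", 0, "w2")
#   ("experts.w13_", "experts.0.linear_v.", 0, "w3")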


@@ -360,6 +360,16 @@ class Grok1Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Map Grok1's unique expert parameter names to standard names
+        # Grok1 uses "num_experts" in its config
+        num_experts = getattr(self.config, "num_experts", 8)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="linear",  # Grok1 specific
+            ckpt_down_proj_name="linear_1",  # Grok1 specific
+            ckpt_up_proj_name="linear_v",  # Grok1 specific
+            num_experts=num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -369,18 +379,9 @@ class Grok1Model(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        # Map Grok1's unique expert parameter names to standard names
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(self.config, "num_experts", 8)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="linear",  # Grok1 specific
-            ckpt_down_proj_name="linear_1",  # Grok1 specific
-            ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
@@ -544,3 +545,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             skip_prefixes=skip_prefixes,
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
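
For context on how a loader consumes this mapping: for each checkpoint tensor, it finds the entry whose weight_name appears in the tensor's name, rewrites that substring to the fused param_name, and hands expert_id/shard_id to the parameter's weight loader. A hedged sketch mirroring the load_weights loop above (the weight_loader attribute and its exact call signature are assumptions for illustration, not the vLLM code):

# Sketch of the consumption side of get_expert_mapping().
def load_expert_weights(
        weights,                 # iterable of (name, tensor) pairs
        params_dict,             # model param name -> parameter object
        expert_params_mapping,   # result of get_expert_mapping()
) -> set[str]:
    loaded: set[str] = set()
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # e.g. "...experts.0.linear.weight" -> "...experts.w13_weight"
            model_name = name.replace(weight_name, param_name)
            param = params_dict[model_name]
            # Assumed: fused MoE parameters carry a weight_loader callable
            # that scatters the per-expert shard into the fused tensor.
            param.weight_loader(param, loaded_weight, model_name,
                                shard_id=shard_id, expert_id=expert_id)
            loaded.add(model_name)
            break
    return loaded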