[Quantization] Enable BNB support for more MoE models (#21100)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -360,6 +360,16 @@ class Grok1Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Map Grok1's unique expert parameter names to standard names
|
||||
# Grok1 uses "num_experts" in its config
|
||||
num_experts = getattr(self.config, "num_experts", 8)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="linear", # Grok1 specific
|
||||
ckpt_down_proj_name="linear_1", # Grok1 specific
|
||||
ckpt_up_proj_name="linear_v", # Grok1 specific
|
||||
num_experts=num_experts)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
@@ -369,18 +379,9 @@ class Grok1Model(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
# Map Grok1's unique expert parameter names to standard names
|
||||
# Grok1 uses "num_experts" in its config
|
||||
num_experts = getattr(self.config, "num_experts", 8)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="linear", # Grok1 specific
|
||||
ckpt_down_proj_name="linear_1", # Grok1 specific
|
||||
ckpt_up_proj_name="linear_v", # Grok1 specific
|
||||
num_experts=num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
@@ -544,3 +545,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
skip_prefixes=skip_prefixes,
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
Reference in New Issue
Block a user