[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-07-11 16:01:13 +08:00
committed by GitHub
parent 762be26a8e
commit 8020e98c9f
8 changed files with 561 additions and 88 deletions

View File

@@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the FusedMoE expert parameter mapping for this model.

    Each entry is a (param_name, weight_name, expert_id, shard_id)
    tuple covering weights and, where present, fp8 weight/activation
    scales for every expert.
    """
    # Checkpoint projection names used by the Qwen2-MoE checkpoints.
    ckpt_proj_names = {
        "ckpt_gate_proj_name": "gate_proj",
        "ckpt_down_proj_name": "down_proj",
        "ckpt_up_proj_name": "up_proj",
    }
    return FusedMoE.make_expert_params_mapping(
        num_experts=self.config.num_experts, **ckpt_proj_names)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
@@ -441,11 +442,13 @@ class Qwen2MoeModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
for mapping in self.get_expert_mapping():
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if "layers.13.mlp.experts.w2_weight" in name:
pass
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
@@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module):
class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Expose the inner model's expert parameter mapping.

    Pure delegation: forwards to ``self.model.get_expert_mapping()``,
    which yields (param_name, weight_name, expert_id, shard_id) tuples.
    """
    inner = self.model
    return inner.get_expert_mapping()