[Quantization] Enable BNB support for more MoE models (#21100)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-07-19 08:52:02 +08:00
Committed by: GitHub
Parent: 217937221b
Commit: 466e878f2a
5 changed files with 223 additions and 181 deletions


@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
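
The import swap above pairs with the class-signature change later in this diff: BailingMoeForCausalLM now also advertises SupportsLoRA. As a rough sketch (not the actual vLLM helper), marker interfaces of this kind are typically consulted through a class-level flag:

```python
# Illustrative sketch only: vLLM-style marker interfaces such as
# SupportsPP and SupportsLoRA advertise a capability through a
# class-level flag; the helper name below is hypothetical.
def model_supports_lora(model_cls: type) -> bool:
    # SupportsLoRA subclasses are expected to carry supports_lora = True.
    return bool(getattr(model_cls, "supports_lora", False))
```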
@@ -374,6 +374,14 @@ class BailingMoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
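
For orientation, here is a minimal sketch of the (param_name, weight_name, expert_id, shard_id) tuples that FusedMoE.make_expert_params_mapping is expected to yield. The "experts.w13_"/"experts.w2_" prefixes are assumptions about vLLM's fused-parameter naming, not copied from this diff:

```python
# Sketch of the mapping get_expert_mapping() returns; prefixes and
# shard ids here are illustrative assumptions, not vLLM source.
def sketch_expert_params_mapping(
        num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping: list[tuple[str, str, int, str]] = []
    for expert_id in range(num_experts):
        for ckpt_name, shard_id in (("gate_proj", "w1"),
                                    ("up_proj", "w3"),
                                    ("down_proj", "w2")):
            # gate_proj/up_proj shards load into the fused w13 parameter;
            # down_proj loads into w2.
            param_name = ("experts.w13_"
                          if shard_id in ("w1", "w3") else "experts.w2_")
            mapping.append((param_name,
                            f"experts.{expert_id}.{ckpt_name}.",
                            expert_id, shard_id))
    return mapping
```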
@@ -381,14 +389,10 @@ class BailingMoeModel(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.norm_head and "lm_head.weight" in name:
                 loaded_weight = F.normalize(loaded_weight,
@@ -449,7 +453,7 @@ class BailingMoeModel(nn.Module):
         return loaded_params
 
-class BailingMoeForCausalLM(nn.Module, SupportsPP):
+class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -518,3 +522,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP):
                 if self.config.tie_word_embeddings else None),
             )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
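
Why the hook matters for BNB: a quantization-aware loader can now ask any model for its expert mapping up front instead of hard-coding it per model. A hedged sketch of such a consumer follows; the function is hypothetical and is not vLLM's BitsAndBytes loader:

```python
# Hypothetical consumer sketch: a quantization-aware loader querying
# the new hook. Neither the function nor its use of the mapping is
# taken from this commit; it only illustrates the intended contract.
import torch.nn as nn

def expert_weight_substrings(model: nn.Module) -> set[str]:
    substrings: set[str] = set()
    if hasattr(model, "get_expert_mapping"):
        for _param_name, weight_name, _expert_id, _shard_id in (
                model.get_expert_mapping()):
            # e.g. "experts.0.gate_proj." -- matched against checkpoint
            # tensor names to route per-expert shards into fused params.
            substrings.add(weight_name)
    return substrings
```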