[Quantization] Enable BNB support for more MoE models (#21100)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
@@ -374,6 +374,14 @@ class BailingMoeModel(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
@@ -381,14 +389,10 @@ class BailingMoeModel(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.config.norm_head and "lm_head.weight" in name:
|
||||
loaded_weight = F.normalize(loaded_weight,
|
||||
@@ -449,7 +453,7 @@ class BailingMoeModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BailingMoeForCausalLM(nn.Module, SupportsPP):
|
||||
class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
@@ -518,3 +522,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP):
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
Reference in New Issue
Block a user