[Quantization] Enable BNB support for InternS1 (#21953)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-08-01 19:09:54 +08:00
committed by GitHub
parent 4931486988
commit 28b18cc741
2 changed files with 43 additions and 16 deletions

View File

@@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import (get_packed_modules_mapping,
from vllm.model_executor.utils import (get_moe_expert_mapping,
get_packed_modules_mapping,
set_weight_attrs)
from vllm.platforms import current_platform
@@ -43,6 +44,12 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
def is_moe_model(model: torch.nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers."""
return bool(any(
isinstance(module, FusedMoE) for module in model.modules()))
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
@@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
# Store the mapping of expert parameters for MoE models.
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
self.pre_quant: bool = False
@@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
elif (isinstance(module, FusedMoE)
and hasattr(module.quant_method, "quant_config")):
if not hasattr(model, "get_expert_mapping"):
raise AttributeError(
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method.")
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"):
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant:
raise ValueError(
@@ -430,9 +434,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet.")
# Get the corresponding weight name using module name and
# get_expert_mapping.
expert_mapping = model.get_expert_mapping()
for exp in expert_mapping:
# expert_params_mapping.
for exp in self.expert_params_mapping:
weight_name = exp[1]
rep_name = name.replace("experts",
"") + weight_name.removesuffix(".")
@@ -464,7 +468,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
expert_mapping = model.get_expert_mapping()
expert_mapping = self.expert_params_mapping
for exp in expert_mapping:
if exp[-1] == "w2":
weight_name = exp[1]
@@ -516,6 +520,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self.is_pool_model = is_pooling_model(model)
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
if is_moe_model(model):
self.expert_params_mapping = get_moe_expert_mapping(model)
if not self.expert_params_mapping:
raise AttributeError(
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method.")
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
@@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"""
from bitsandbytes.functional import QuantState
if not hasattr(model, "get_expert_mapping"):
if not self.expert_params_mapping:
return dict()
expert_mapping = model.get_expert_mapping()
expert_mapping = self.expert_params_mapping
expert_qs_dict = {}
for name, module in model.named_modules():
if not isinstance(module, FusedMoE):