[Quantization] Enable BNB support for InternS1 (#21953)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import (get_packed_modules_mapping,
|
||||
from vllm.model_executor.utils import (get_moe_expert_mapping,
|
||||
get_packed_modules_mapping,
|
||||
set_weight_attrs)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -43,6 +44,12 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_moe_model(model: torch.nn.Module) -> bool:
    """Check whether the model contains FusedMoE layers.

    Args:
        model: The module whose ``modules()`` iterator is scanned.

    Returns:
        True if at least one module yielded by ``model.modules()`` is an
        instance of ``FusedMoE``, otherwise False.
    """
    # any() already returns a bool and stops at the first match, so no
    # extra bool() wrapper is needed.
    return any(isinstance(module, FusedMoE) for module in model.modules())
|
||||
|
||||
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
@@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# Store all module names (from transformers) that support
|
||||
# BNB quantization.
|
||||
self.target_modules: list[str] = []
|
||||
# Store the mapping of expert parameters for MoE models.
|
||||
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
|
||||
# mapping weight names from transformers to vllm.
|
||||
self.weight_mapper: Callable = lambda name: name
|
||||
self.pre_quant: bool = False
|
||||
@@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# in case model has a mixture of disk-merged and disk-split
|
||||
# weights with same last name.
|
||||
self.target_modules.append(name)
|
||||
elif (isinstance(module, FusedMoE)
|
||||
and hasattr(module.quant_method, "quant_config")):
|
||||
if not hasattr(model, "get_expert_mapping"):
|
||||
raise AttributeError(
|
||||
f"MoE Model {type(model).__name__} does not support "
|
||||
"BitsAndBytes quantization yet. Ensure this model has "
|
||||
"'get_expert_mapping' method.")
|
||||
elif isinstance(module, FusedMoE) and hasattr(
|
||||
module.quant_method, "quant_config"):
|
||||
# TODO: support FusedMoE with prequant and 8bit.
|
||||
if self.pre_quant:
|
||||
raise ValueError(
|
||||
@@ -430,9 +434,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"BitsAndBytes 8bit quantization with FusedMoE is not "
|
||||
"supported yet.")
|
||||
# Get the corresponding weight name using module name and
|
||||
# get_expert_mapping.
|
||||
expert_mapping = model.get_expert_mapping()
|
||||
for exp in expert_mapping:
|
||||
# expert_params_mapping.
|
||||
|
||||
for exp in self.expert_params_mapping:
|
||||
weight_name = exp[1]
|
||||
rep_name = name.replace("experts",
|
||||
"") + weight_name.removesuffix(".")
|
||||
@@ -464,7 +468,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
elif isinstance(module, (RowParallelLinear, )):
|
||||
self.column_sharded_weights_modules.append(name)
|
||||
elif isinstance(module, FusedMoE):
|
||||
expert_mapping = model.get_expert_mapping()
|
||||
expert_mapping = self.expert_params_mapping
|
||||
for exp in expert_mapping:
|
||||
if exp[-1] == "w2":
|
||||
weight_name = exp[1]
|
||||
@@ -516,6 +520,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
self.is_pool_model = is_pooling_model(model)
|
||||
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
|
||||
|
||||
if is_moe_model(model):
|
||||
self.expert_params_mapping = get_moe_expert_mapping(model)
|
||||
if not self.expert_params_mapping:
|
||||
raise AttributeError(
|
||||
f"MoE Model {type(model).__name__} does not support "
|
||||
"BitsAndBytes quantization yet. Ensure this model has "
|
||||
"'get_expert_mapping' method.")
|
||||
# For some models like Molmo, we need to use hf_to_vllm_mapper
|
||||
# to ensure correct loading of weights.
|
||||
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
|
||||
@@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
if not hasattr(model, "get_expert_mapping"):
|
||||
if not self.expert_params_mapping:
|
||||
return dict()
|
||||
|
||||
expert_mapping = model.get_expert_mapping()
|
||||
expert_mapping = self.expert_params_mapping
|
||||
expert_qs_dict = {}
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance(module, FusedMoE):
|
||||
|
||||
Reference in New Issue
Block a user