[Core] Allow disabling TP sharding for parallel Linear layer (#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Isotr0py
Date: 2025-09-06 13:53:58 +08:00 (committed by GitHub)
Parent: 6432739ef1
Commit: 53b19ccdd5
7 changed files with 203 additions and 280 deletions
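
For context on the feature itself: parallel Linear layers can now opt out of tensor-parallel sharding, and the loader changes below teach the BitsAndBytes path to respect that. A minimal sketch of how a model might request this, assuming the flag is exposed as a constructor argument mirroring the `module.disable_tp` attribute checked in the diff (the exact signature may differ):

    from vllm.model_executor.layers.linear import RowParallelLinear

    # Keep the full weight on every TP rank instead of sharding it.
    shared_proj = RowParallelLinear(
        input_size=4096,
        output_size=4096,
        bias=False,
        disable_tp=True,  # assumed kwarg; the diff only shows the attribute
    )

The BitsAndBytes loader then has to quantize and load such a module's weight unsharded, which is what the tp_disabled_modules bookkeeping below implements.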


@@ -69,6 +69,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
+        self.tp_disabled_modules: list[str] = []
         # Store the mapping of expert parameters for MoE models.
         self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
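
To make the new bookkeeping concrete, a hypothetical snapshot of the two lists after scanning a model (module names invented for illustration): target_modules holds every name eligible for BNB quantization, while tp_disabled_modules holds the subset whose weights must not be sharded:

    # Illustrative values only; real names depend on the model.
    target_modules = [
        "model.layers.0.self_attn.q_proj",    # expanded from the fused qkv_proj
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
        "model.layers.0.self_attn.qkv_proj",  # original fused name kept as well
    ]
    tp_disabled_modules = [
        "model.layers.0.mlp.shared_proj",     # a layer built with disable_tp=True
    ]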
@@ -322,14 +323,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                                   quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
+        global_tp_size = get_tensor_model_parallel_world_size()
+        global_tp_rank = get_tensor_model_parallel_rank()
         for (
                 org_weight_name,
                 mapped_weight_name,
                 weight_tensor,
         ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            # override tp_size and tp_rank if the module has disabled TP
+            if any(tp_disabled_module in mapped_weight_name
+                   for tp_disabled_module in self.tp_disabled_modules):
+                tp_size = 1
+                tp_rank = 0
+            else:
+                tp_size = global_tp_size
+                tp_rank = global_tp_rank
             if any(target_module in mapped_weight_name
                    for target_module in self.target_modules
                    ) and mapped_weight_name.endswith(".weight"):
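
The override matters because tp_size and tp_rank drive the shard slicing applied to each weight before 4-bit quantization. A self-contained sketch of that arithmetic (not the actual vLLM slicing code): with tp_size=1 and tp_rank=0 the slice covers the whole tensor, so a TP-disabled module is loaded unsharded on every rank.

    import torch

    def shard_rows(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
        # Select this rank's contiguous block of rows along the sharded dimension.
        rows_per_rank = weight.shape[0] // tp_size
        start = tp_rank * rows_per_rank
        return weight[start:start + rows_per_rank]

    w = torch.randn(8, 4)
    assert shard_rows(w, tp_size=1, tp_rank=0).shape == (8, 4)  # TP disabled: full weight
    assert shard_rows(w, tp_size=2, tp_rank=1).shape == (4, 4)  # normal TP: half the rows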
@@ -418,12 +429,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
                     for sub_name in sub_modules:
-                        self.target_modules.append(
-                            name.replace(rep_name, sub_name))
+                        new_name = name.replace(rep_name, sub_name)
+                        self.target_modules.append(new_name)
+                        if module.disable_tp:
+                            self.tp_disabled_modules.append(new_name)
                 # Add original module name even if the module has stacked map,
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
+                if module.disable_tp:
+                    self.tp_disabled_modules.append(name)
             elif isinstance(module, FusedMoE) and hasattr(
                     module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.
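
The expansion above handles vLLM's stacked parameters: a fused module name such as qkv_proj corresponds to several separate names in the transformers checkpoint, so both the expanded names and the original fused name are tracked (the latter for checkpoints that store merged weights on disk). When the module was built with disable_tp, every resulting name is also recorded in tp_disabled_modules. A standalone rendering of that logic, with an invented helper name for illustration:

    def expand_target_modules(name: str, rep_name: str, sub_modules: list[str],
                              disable_tp: bool) -> tuple[list[str], list[str]]:
        targets: list[str] = []
        tp_disabled: list[str] = []
        for sub_name in sub_modules:
            new_name = name.replace(rep_name, sub_name)
            targets.append(new_name)
            if disable_tp:
                tp_disabled.append(new_name)
        # Keep the fused name too, for disk-merged weights with the same last name.
        targets.append(name)
        if disable_tp:
            tp_disabled.append(name)
        return targets, tp_disabled

    targets, tp_disabled = expand_target_modules(
        "model.layers.0.self_attn.qkv_proj", "qkv_proj",
        ["q_proj", "k_proj", "v_proj"], disable_tp=True)
    # targets ends with the fused name; tp_disabled mirrors targets because
    # disable_tp was set for this module.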