[Core] Allow disabling TP sharding for parallel Linear layer (#23024)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
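For orientation, here is a minimal, illustrative sketch (not vLLM code) of what "disabling TP sharding" means for a parallel Linear weight: with tensor parallelism each rank normally keeps only a 1/tp_size slice of the weight, whereas a module whose TP is disabled keeps the full tensor on every rank. This is the behaviour the BitsAndBytes loader changes below must respect when slicing weights before quantization.

import torch

def load_linear_weight(full_weight: torch.Tensor, tp_size: int, tp_rank: int,
                       disable_tp: bool, shard_dim: int = 0) -> torch.Tensor:
    # Illustrative helper only: a TP-sharded Linear keeps 1/tp_size of the
    # weight per rank; a TP-disabled module keeps the complete tensor.
    if disable_tp or tp_size == 1:
        return full_weight
    shard_size = full_weight.shape[shard_dim] // tp_size
    return full_weight.narrow(shard_dim, tp_rank * shard_size, shard_size)

w = torch.randn(4096, 11008)
print(load_linear_weight(w, tp_size=4, tp_rank=1, disable_tp=False).shape)  # torch.Size([1024, 11008])
print(load_linear_weight(w, tp_size=4, tp_rank=1, disable_tp=True).shape)   # torch.Size([4096, 11008])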
@@ -69,6 +69,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
+        self.tp_disabled_modules: list[str] = []
         # Store the mapping of expert parameters for MoE models.
         self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
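The new list is consumed later by plain substring matching against each mapped weight name, just as target_modules already is; a tiny illustration with made-up names:

# Example names only; the real lists are filled from the model's modules.
target_modules = ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.down_proj"]
tp_disabled_modules = ["model.layers.0.mlp.down_proj"]

mapped_weight_name = "model.layers.0.mlp.down_proj.weight"
is_bnb_target = any(m in mapped_weight_name for m in target_modules)        # True
keep_unsharded = any(m in mapped_weight_name for m in tp_disabled_modules)  # True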
@@ -322,14 +323,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                                   quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
 
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
+        global_tp_size = get_tensor_model_parallel_world_size()
+        global_tp_rank = get_tensor_model_parallel_rank()
 
         for (
             org_weight_name,
             mapped_weight_name,
             weight_tensor,
         ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+
+            # override tp_size and tp_rank if the module has disabled TP
+            if any(tp_disabled_module in mapped_weight_name
+                   for tp_disabled_module in self.tp_disabled_modules):
+                tp_size = 1
+                tp_rank = 0
+            else:
+                tp_size = global_tp_size
+                tp_rank = global_tp_rank
+
             if any(target_module in mapped_weight_name
                    for target_module in self.target_modules
                    ) and mapped_weight_name.endswith(".weight"):
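Downstream of this hunk the generator slices each weight for the current rank before handing it to quantize_4bit, so forcing tp_size=1 and tp_rank=0 makes that slice degenerate to the whole tensor. A small worked example of the per-rank slice bounds (illustrative arithmetic, not the loader's exact slicing code):

total_rows = 11008                            # rows of a hypothetical weight
for tp_size, tp_rank in [(4, 1), (1, 0)]:     # normal TP shard vs. TP disabled
    shard = total_rows // tp_size
    start, end = tp_rank * shard, tp_rank * shard + shard
    print(f"tp_size={tp_size} tp_rank={tp_rank} -> rows [{start}, {end})")
# tp_size=4 tp_rank=1 -> rows [2752, 5504)    (rank 1 loads a quarter)
# tp_size=1 tp_rank=0 -> rows [0, 11008)      (every rank loads the full weight)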
@@ -418,12 +429,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
                     for sub_name in sub_modules:
-                        self.target_modules.append(
-                            name.replace(rep_name, sub_name))
+                        new_name = name.replace(rep_name, sub_name)
+                        self.target_modules.append(new_name)
+                        if module.disable_tp:
+                            self.tp_disabled_modules.append(new_name)
                     # Add original module name even if the module has stacked map,
                     # in case model has a mixture of disk-merged and disk-split
                     # weights with same last name.
                     self.target_modules.append(name)
+                    if module.disable_tp:
+                        self.tp_disabled_modules.append(name)
             elif isinstance(module, FusedMoE) and hasattr(
                     module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.
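To make the registration above concrete: for one hypothetical fused module with TP disabled, the loop records both the split sub-module names and the original merged name in both lists (the module name and the gate_up_proj -> gate_proj/up_proj stacked mapping below are illustrative):

# Hypothetical fused module with a stacked mapping and disable_tp set.
name = "model.layers.0.mlp.gate_up_proj"
rep_name, sub_modules = "gate_up_proj", ["gate_proj", "up_proj"]
disable_tp = True

target_modules, tp_disabled_modules = [], []
for sub_name in sub_modules:
    new_name = name.replace(rep_name, sub_name)
    target_modules.append(new_name)
    if disable_tp:
        tp_disabled_modules.append(new_name)
# Keep the merged name too, for checkpoints that store the fused weight on disk.
target_modules.append(name)
if disable_tp:
    tp_disabled_modules.append(name)

print(target_modules)
# ['model.layers.0.mlp.gate_proj', 'model.layers.0.mlp.up_proj',
#  'model.layers.0.mlp.gate_up_proj']
print(tp_disabled_modules == target_modules)   # True for this module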