[Misc] Add BNB quantization for Whisper (#12381)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -803,9 +803,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
iterator = safetensors_weights_iterator(hf_weights_files)
|
iterator = safetensors_weights_iterator(hf_weights_files)
|
||||||
else:
|
else:
|
||||||
iterator = pt_weights_iterator(hf_weights_files)
|
iterator = pt_weights_iterator(hf_weights_files)
|
||||||
for name, param in iterator:
|
for org_name, param in iterator:
|
||||||
# mapping weight names from transformers to vllm.
|
# mapping weight names from transformers to vllm while preserving
|
||||||
yield self.weight_mapper(name), param
|
# original names.
|
||||||
|
mapped_name = self.weight_mapper(org_name)
|
||||||
|
yield org_name, mapped_name, param
|
||||||
|
|
||||||
def _get_quantized_weights_iterator(
|
def _get_quantized_weights_iterator(
|
||||||
self,
|
self,
|
||||||
@@ -866,24 +868,30 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
|
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
|
||||||
quant_state_dict) -> Generator:
|
quant_state_dict) -> Generator:
|
||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for (
|
||||||
hf_weights_files, use_safetensors):
|
org_weight_name,
|
||||||
if not weight_name.lower().endswith(".scb"):
|
mapped_weight_name,
|
||||||
|
weight_tensor,
|
||||||
|
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||||
|
if not mapped_weight_name.lower().endswith(".scb"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight_key = weight_name.lower().replace(".scb", ".weight")
|
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
|
||||||
quant_state_dict[weight_key] = weight_tensor
|
quant_state_dict[weight_key] = weight_tensor
|
||||||
|
|
||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for (
|
||||||
hf_weights_files, use_safetensors):
|
org_weight_name,
|
||||||
if self._is_8bit_weight_name(weight_name):
|
mapped_weight_name,
|
||||||
|
weight_tensor,
|
||||||
|
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||||
|
if self._is_8bit_weight_name(mapped_weight_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if weight_name in quant_state_dict:
|
if mapped_weight_name in quant_state_dict:
|
||||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||||
yield weight_name, weight_tensor
|
yield org_weight_name, weight_tensor
|
||||||
else:
|
else:
|
||||||
yield weight_name, weight_tensor
|
yield org_weight_name, weight_tensor
|
||||||
|
|
||||||
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
|
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
|
||||||
quant_state_dict) -> Generator:
|
quant_state_dict) -> Generator:
|
||||||
@@ -893,15 +901,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
weight_iterator = self._hf_weight_iter(hf_weights_files,
|
weight_iterator = self._hf_weight_iter(hf_weights_files,
|
||||||
use_safetensors)
|
use_safetensors)
|
||||||
temp_state_dict = {}
|
temp_state_dict = {}
|
||||||
for weight_name, weight_tensor in weight_iterator:
|
for (
|
||||||
if not self._is_4bit_weight_name(weight_name):
|
org_weight_name,
|
||||||
|
mapped_weight_name,
|
||||||
|
weight_tensor,
|
||||||
|
) in weight_iterator:
|
||||||
|
if not self._is_4bit_weight_name(mapped_weight_name):
|
||||||
continue
|
continue
|
||||||
# bitsandbytes library requires
|
# bitsandbytes library requires
|
||||||
# weight.quant_state.bitsandbytes__* in CPU
|
# weight.quant_state.bitsandbytes__* in CPU
|
||||||
if "quant_state.bitsandbytes" in weight_name:
|
if "quant_state.bitsandbytes" in mapped_weight_name:
|
||||||
temp_state_dict[weight_name] = weight_tensor.cpu().data
|
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
|
||||||
else:
|
else:
|
||||||
temp_state_dict[weight_name] = weight_tensor
|
temp_state_dict[mapped_weight_name] = weight_tensor
|
||||||
|
|
||||||
# Closure to parse quant_state for each prequant weight
|
# Closure to parse quant_state for each prequant weight
|
||||||
def _parse_quant_state(param_name: str,
|
def _parse_quant_state(param_name: str,
|
||||||
@@ -915,20 +927,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
# Second iterate over all prequant and normal weights
|
# Second iterate over all prequant and normal weights
|
||||||
# pre quantized weights would have a quant_state
|
# pre quantized weights would have a quant_state
|
||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for (
|
||||||
hf_weights_files, use_safetensors):
|
org_weight_name,
|
||||||
if self._is_4bit_weight_name(weight_name):
|
mapped_weight_name,
|
||||||
|
weight_tensor,
|
||||||
|
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||||
|
if self._is_4bit_weight_name(mapped_weight_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (f"{weight_name}.quant_state.bitsandbytes__nf4"
|
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
|
||||||
in temp_state_dict) or (
|
in temp_state_dict) or (
|
||||||
f"{weight_name}.quant_state.bitsandbytes__fp4"
|
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
|
||||||
in temp_state_dict):
|
in temp_state_dict):
|
||||||
quant_state = _parse_quant_state(weight_name, temp_state_dict)
|
quant_state = _parse_quant_state(mapped_weight_name,
|
||||||
quant_state_dict[weight_name] = quant_state
|
temp_state_dict)
|
||||||
yield weight_name, weight_tensor
|
quant_state_dict[mapped_weight_name] = quant_state
|
||||||
|
yield org_weight_name, weight_tensor
|
||||||
else:
|
else:
|
||||||
yield weight_name, weight_tensor
|
yield org_weight_name, weight_tensor
|
||||||
|
|
||||||
def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
||||||
quant_state_dict) -> Generator:
|
quant_state_dict) -> Generator:
|
||||||
@@ -937,18 +953,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for (
|
||||||
hf_weights_files, use_safetensors):
|
org_weight_name,
|
||||||
if any(target_module in weight_name for target_module in
|
mapped_weight_name,
|
||||||
self.target_modules) and weight_name.endswith(".weight"):
|
weight_tensor,
|
||||||
|
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||||
|
if any(target_module in mapped_weight_name
|
||||||
|
for target_module in self.target_modules
|
||||||
|
) and mapped_weight_name.endswith(".weight"):
|
||||||
# Without sharding
|
# Without sharding
|
||||||
if any(
|
if any(
|
||||||
weight_name.startswith(module)
|
mapped_weight_name.startswith(module)
|
||||||
for module in self.unsharded_weights_modules):
|
for module in self.unsharded_weights_modules):
|
||||||
weight_sub_tensor = weight_tensor
|
weight_sub_tensor = weight_tensor
|
||||||
# Shard by column
|
# Shard by column
|
||||||
elif any(
|
elif any(
|
||||||
weight_name.startswith(module)
|
mapped_weight_name.startswith(module)
|
||||||
for module in self.column_sharded_weights_modules):
|
for module in self.column_sharded_weights_modules):
|
||||||
total_size = weight_tensor.size(-1)
|
total_size = weight_tensor.size(-1)
|
||||||
start_index = total_size // tp_size * tp_rank
|
start_index = total_size // tp_size * tp_rank
|
||||||
@@ -958,14 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
# Weights have fused on disk. In this case, we assume that the
|
# Weights have fused on disk. In this case, we assume that the
|
||||||
# weight and module use same name.
|
# weight and module use same name.
|
||||||
elif any(
|
elif any(
|
||||||
weight_name.startswith(module)
|
mapped_weight_name.startswith(module)
|
||||||
for module in self.maybe_fused_weights_modules):
|
for module in self.maybe_fused_weights_modules):
|
||||||
# special case for fused weights
|
# special case for fused weights
|
||||||
# get the size of each shard weight tensor
|
# get the size of each shard weight tensor
|
||||||
total_shard_sizes = next(
|
total_shard_sizes = next(
|
||||||
(sizes for module, sizes in
|
(sizes for module, sizes in
|
||||||
self.maybe_fused_weights_modules.items()
|
self.maybe_fused_weights_modules.items()
|
||||||
if weight_name.startswith(module)))
|
if mapped_weight_name.startswith(module)))
|
||||||
total_size = weight_tensor.size(0)
|
total_size = weight_tensor.size(0)
|
||||||
assert total_size == sum(total_shard_sizes)
|
assert total_size == sum(total_shard_sizes)
|
||||||
# get the start/end index of each shard weight tensor
|
# get the start/end index of each shard weight tensor
|
||||||
@@ -1008,23 +1028,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
quant_type="nf4",
|
quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_state_dict[weight_name] = quant_state
|
quant_state_dict[mapped_weight_name] = quant_state
|
||||||
else:
|
else:
|
||||||
processed_weight = weight_tensor
|
processed_weight = weight_tensor
|
||||||
|
yield org_weight_name, processed_weight
|
||||||
yield weight_name, processed_weight
|
|
||||||
|
|
||||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, (LinearBase, )):
|
if isinstance(module, (LinearBase, )):
|
||||||
last_name = name.split(".")[-1]
|
if modules_info := self.modules_mapping.get_sub_modules(name):
|
||||||
if sub_modules := self.modules_mapping.packed_mapping.get(
|
|
||||||
last_name, []):
|
|
||||||
# Map vllm's names to transformers's names.
|
# Map vllm's names to transformers's names.
|
||||||
|
rep_name, sub_modules = modules_info
|
||||||
for sub_name in sub_modules:
|
for sub_name in sub_modules:
|
||||||
self.target_modules.append(
|
self.target_modules.append(
|
||||||
name.replace(last_name, sub_name))
|
name.replace(rep_name, sub_name))
|
||||||
# Add original module name even if the module has stacked map,
|
# Add original module name even if the module has stacked map,
|
||||||
# in case model has a mixture of disk-merged and disk-splitted
|
# in case model has a mixture of disk-merged and disk-splitted
|
||||||
# weights with same last name.
|
# weights with same last name.
|
||||||
|
|||||||
@@ -131,3 +131,10 @@ class ParamMapping:
|
|||||||
packed_name,
|
packed_name,
|
||||||
index,
|
index,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_sub_modules(self,
|
||||||
|
module_name: str) -> Optional[Tuple[str, List[str]]]:
|
||||||
|
for key, value in self.packed_mapping.items():
|
||||||
|
if module_name.endswith(key):
|
||||||
|
return key, value
|
||||||
|
return None
|
||||||
|
|||||||
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
|
|||||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||||
"audio", get_max_whisper_audio_tokens)
|
"audio", get_max_whisper_audio_tokens)
|
||||||
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"self_attn.qkv_proj": [
|
||||||
|
"self_attn.q_proj",
|
||||||
|
"self_attn.k_proj",
|
||||||
|
"self_attn.v_proj",
|
||||||
|
],
|
||||||
|
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
|
||||||
|
}
|
||||||
|
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
|
||||||
|
".fc1.": ".mlp.fc1.",
|
||||||
|
".fc2.": ".mlp.fc2."
|
||||||
|
})
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
||||||
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
|
|
||||||
# add fake zeros bias for k_proj to state_dict
|
# add fake zeros bias for k_proj to state_dict
|
||||||
weights = _create_fake_bias_for_k_proj(weights)
|
weights = _create_fake_bias_for_k_proj(weights)
|
||||||
return loader.load_weights(weights, mapper=mapper)
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
|
||||||
def _create_fake_bias_for_k_proj(
|
def _create_fake_bias_for_k_proj(
|
||||||
|
|||||||
Reference in New Issue
Block a user