[Misc] Add BNB quantization for Whisper (#12381)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-02-04 16:27:36 +08:00
Committed by: GitHub
Parent: c36ac98d01
Commit: 96b23621c1
3 changed files with 82 additions and 44 deletions

vllm/model_executor/models/whisper.py

@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_max_whisper_audio_tokens)
 class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
+    packed_modules_mapping = {
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
+    }
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        ".fc1.": ".mlp.fc1.",
+        ".fc2.": ".mlp.fc2."
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
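
Note on the hunk above: packed_modules_mapping tells vLLM's weight loaders
which per-projection HF weights are fused into a single vLLM module (qkv_proj,
kv_proj), which is what lets bitsandbytes quantize the fused projections
correctly. A minimal usage sketch follows, assuming the in-flight BNB path of
this era where quantization and load_format are both set to "bitsandbytes";
the checkpoint name is illustrative:

from vllm import LLM

# Sketch: run Whisper with in-flight bitsandbytes quantization.
# "openai/whisper-large-v3" is an illustrative checkpoint name.
llm = LLM(
    model="openai/whisper-large-v3",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)
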
@@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
-        mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(weights, mapper=mapper)
+        # add fake zeros bias for k_proj to state_dict
+        weights = _create_fake_bias_for_k_proj(weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+
+def _create_fake_bias_for_k_proj(
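
The diff view is truncated at the new helper's signature. Below is a plausible
sketch of its body, inferred from the "add fake zeros bias for k_proj" comment
rather than taken from the commit: HF Whisper checkpoints ship biases for
q_proj and v_proj but none for k_proj, so a zero bias is synthesized to let the
packed qkv_proj (and kv_proj) loaders find a bias for every shard.

from typing import Iterable, Tuple

import torch


def _create_fake_bias_for_k_proj(
        weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    # Assumed body, not the committed code: pass every weight through, and
    # next to each k_proj weight (self- or cross-attention) emit an all-zeros
    # bias of matching output dimension so bias packing sees all shards.
    for name, weight in weights:
        yield name, weight
        if name.endswith(".k_proj.weight"):
            yield name.replace(".weight", ".bias"), torch.zeros(weight.size(0))
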