[Misc] Add BNB quantization for Whisper (#12381)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", get_max_whisper_audio_tokens)
|
||||
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
],
|
||||
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
|
||||
".fc1.": ".mlp.fc1.",
|
||||
".fc2.": ".mlp.fc2."
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
||||
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
|
||||
|
||||
# add fake zeros bias for k_proj to state_dict
|
||||
weights = _create_fake_bias_for_k_proj(weights)
|
||||
return loader.load_weights(weights, mapper=mapper)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
def _create_fake_bias_for_k_proj(
|
||||
|
||||
Reference in New Issue
Block a user