[Model] Add LoRA support for Whisper models (#29856)
Signed-off-by: daje0601 <englishmt4118@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.attention import (
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@@ -66,6 +67,7 @@ from vllm.v1.attention.backend import (
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsTranscription,
|
||||
)
|
||||
@@ -279,11 +281,12 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
)
|
||||
self.kv_proj = QKVParallelLinear(
|
||||
hidden_size=embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=0,
|
||||
total_num_kv_heads=self.total_num_heads,
|
||||
# Use MergedColumnParallelLinear for K and V projections.
|
||||
# This enables LoRA support via MergedColumnParallelLinearWithLoRA
|
||||
# which handles 2-slice configurations.
|
||||
self.kv_proj = MergedColumnParallelLinear(
|
||||
input_size=embed_dim,
|
||||
output_sizes=[embed_dim, embed_dim],
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_proj",
|
||||
@@ -615,8 +618,9 @@ class WhisperModel(nn.Module):
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
|
||||
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
|
||||
# MergedColumnParallelLinear uses integer shard indices (0 = k_proj, 1 = v_proj)
|
||||
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
|
||||
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
@@ -790,14 +794,12 @@ class WhisperForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsTranscription,
|
||||
SupportsMultiModal,
|
||||
SupportsLoRA,
|
||||
):
|
||||
# LoRA-specific attributes
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
],
|
||||
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"kv_proj": ["k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
|
||||
Reference in New Issue
Block a user