[Model] Add LoRA support for Whisper models (#29856)

Signed-off-by: daje0601 <englishmt4118@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Author: daje0601
Date: 2026-03-05 11:38:25 +09:00
Committed by: GitHub
Parent: 2f4226fe52
Commit: 3b23d57c96
4 changed files with 185 additions and 14 deletions

vllm/model_executor/models/whisper.py

@@ -31,6 +31,7 @@ from vllm.model_executor.layers.attention import (
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
@@ -66,6 +67,7 @@ from vllm.v1.attention.backend import (
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsLoRA,
     SupportsMultiModal,
     SupportsTranscription,
 )
@@ -279,11 +281,12 @@ class WhisperCrossAttention(WhisperAttention):
             quant_config=quant_config,
             prefix=f"{prefix}.q_proj",
         )
-        self.kv_proj = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.head_dim,
-            total_num_heads=0,
-            total_num_kv_heads=self.total_num_heads,
+        # Use MergedColumnParallelLinear for K and V projections.
+        # This enables LoRA support via MergedColumnParallelLinearWithLoRA,
+        # which handles 2-slice configurations.
+        self.kv_proj = MergedColumnParallelLinear(
+            input_size=embed_dim,
+            output_sizes=[embed_dim, embed_dim],
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_proj",
@@ -615,8 +618,9 @@ class WhisperModel(nn.Module):
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
             (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
             (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
-            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
-            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
+            # MergedColumnParallelLinear uses integer indices (0, 1)
+            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
+            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
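
Reviewer note: a simplified sketch of how `load_weights` consumes these tuples — the names below (`weights`, `stacked_params_mapping`, `params_dict`) follow the surrounding code, but the loop is abridged from the real implementation:

```python
def load_stacked_weights(weights, stacked_params_mapping, params_dict):
    """Abridged version of the load_weights loop above.

    Each checkpoint tensor whose name contains a per-projection
    substring (e.g. ".encoder_attn.k_proj") is redirected to the
    fused parameter (".encoder_attn.kv_proj"), and shard_id tells
    the fused layer's weight_loader which slice to fill.
    MergedColumnParallelLinear expects integer shard ids (0, 1),
    while QKVParallelLinear expects "q"/"k"/"v".
    """
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            param = params_dict[name.replace(weight_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No fused mapping matched: load the tensor one-to-one.
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)
```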
@@ -790,14 +794,12 @@ class WhisperForConditionalGeneration(
     nn.Module,
     SupportsTranscription,
     SupportsMultiModal,
+    SupportsLoRA,
 ):
+    # LoRA-specific attributes
     packed_modules_mapping = {
-        "self_attn.qkv_proj": [
-            "self_attn.q_proj",
-            "self_attn.k_proj",
-            "self_attn.v_proj",
-        ],
-        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "kv_proj": ["k_proj", "v_proj"],
     }
 
     hf_to_vllm_mapper = WeightsMapper(
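
With `SupportsLoRA` declared, an adapter trained against the Whisper projection names can be attached at request time. A hedged end-to-end sketch — the adapter path is a placeholder, the silent waveform stands in for real audio, and the prompt format follows vLLM's offline Whisper example:

```python
import numpy as np
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="openai/whisper-large-v3",
    enable_lora=True,
    max_lora_rank=32,  # must cover the adapter's rank
)

# Placeholder audio: 1 second of silence at 16 kHz; replace with a
# real waveform (e.g. loaded via librosa or soundfile).
waveform = np.zeros(16000, dtype=np.float32)

outputs = llm.generate(
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {"audio": (waveform, 16000)},
    },
    SamplingParams(temperature=0, max_tokens=200),
    # Hypothetical adapter path: a PEFT checkpoint targeting
    # q_proj / k_proj / v_proj maps onto the packed modules above.
    lora_request=LoRARequest("whisper-ft", 1, "/path/to/whisper-lora"),
)
print(outputs[0].outputs[0].text)
```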