Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
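Note on the new layout: ruff's formatter is Black-compatible, so the rewrapped hunks below follow the magic-trailing-comma convention. A minimal sketch of that convention, assuming ruff's default settings (the exact project configuration is not shown in this excerpt):

    # With a trailing comma inside the parentheses, ruff format keeps
    # one item per line even if the import would fit on a single line.
    from transformers import (
        BatchFeature,
        WhisperConfig,
    )

    # Without a trailing comma, the same kind of import is collapsed onto
    # one line whenever it fits within the configured line length.
    from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig

This is why most multi-line imports and call sites in the diff gain a trailing comma after their last item.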
@@ -9,43 +9,58 @@ from typing import Annotated, Literal, Optional, Union, cast
import numpy as np
import torch
from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers import (
BatchFeature,
WhisperConfig,
WhisperFeatureExtractor,
WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import sinusoids

from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.processing import (
BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers, maybe_prefix)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
AutoWeightsLoader,
WeightsMapper,
cast_overflow_tensors,
make_layers,
maybe_prefix,
)

logger = init_logger(__name__)
@@ -108,7 +123,7 @@ ISO639_1_SUPPORTED_LANGS = {
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh"
"cy": "Welsh",
}
@@ -120,8 +135,7 @@ class WhisperAudioInputs(TensorSchema):
- t: Time frames (M)
"""

input_features: Annotated[Optional[NestedTensors],
TensorShape("b", "nmb", "t")]
input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")]


class WhisperEncoderAttention(MultiHeadAttention):
@@ -153,7 +167,6 @@ class WhisperEncoderAttention(MultiHeadAttention):


class WhisperPositionalEmbedding(nn.Embedding):

def __init__(self, num_positions: int, embedding_dim: int):
super().__init__(num_positions, embedding_dim)

@@ -162,7 +175,6 @@ class WhisperPositionalEmbedding(nn.Embedding):


class WhisperAttention(nn.Module):

def __init__(
self,
embed_dim: int,
@@ -196,7 +208,8 @@ class WhisperAttention(nn.Module):
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
f"{self.embed_dim} and `num_heads`: {num_heads}).")
f"{self.embed_dim} and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5

self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
@@ -269,7 +282,6 @@ class WhisperAttention(nn.Module):


class WhisperCrossAttention(WhisperAttention):

def __init__(
self,
embed_dim: int,
@@ -336,7 +348,6 @@ class WhisperCrossAttention(WhisperAttention):


class WhisperMLP(nn.Module):

def __init__(
self,
embed_dim: int,
@@ -369,7 +380,6 @@ class WhisperMLP(nn.Module):


class WhisperEncoderLayer(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -414,7 +424,6 @@ class WhisperEncoderLayer(nn.Module):


class WhisperDecoderLayer(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -474,48 +483,39 @@ class WhisperDecoderLayer(nn.Module):


class WhisperEncoder(nn.Module):

def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
init_in_fp32: bool = False):
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim)
if config.scale_embedding else 1.0)
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

self.conv1 = nn.Conv1d(self.num_mel_bins,
embed_dim,
kernel_size=3,
padding=1)
self.conv2 = nn.Conv1d(embed_dim,
embed_dim,
kernel_size=3,
stride=2,
padding=1)
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
lambda prefix: WhisperEncoderLayer(
vllm_config=vllm_config, prefix=f"{prefix}.layers"
),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)

maybe_fp32_init_ctx = set_default_torch_dtype(
torch.float32) if init_in_fp32 else nullcontext()
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
)

with (
torch.no_grad(),
maybe_fp32_init_ctx,
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))
sinusoids(*self.embed_positions.weight.shape)
)

def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]):
hidden_states = []
@@ -523,9 +523,9 @@ class WhisperEncoder(nn.Module):
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.transpose(-1, -2)
embeds = (embeds +
self.embed_positions.weight[:embeds.size(-2), :]).to(
embeds.dtype)
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
embeds.dtype
)
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
@@ -537,7 +537,6 @@ class WhisperEncoder(nn.Module):


class WhisperDecoder(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -545,17 +544,19 @@ class WhisperDecoder(nn.Module):
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_target_positions
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(config.d_model)
if config.scale_embedding else 1.0)
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model,
self.padding_idx)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.d_model, self.padding_idx
)
self.embed_positions = WhisperPositionalEmbedding(
self.max_target_positions, config.d_model)
self.max_target_positions, config.d_model
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.decoder_layers,
lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
lambda prefix: WhisperDecoderLayer(
vllm_config=vllm_config, prefix=f"{prefix}.layers"
),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -584,13 +585,14 @@ class WhisperDecoder(nn.Module):


class WhisperModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
self.decoder = WhisperDecoder(vllm_config=vllm_config,
prefix=f"{prefix}.decoder")
self.encoder = WhisperEncoder(
vllm_config=vllm_config, prefix=f"{prefix}.encoder"
)
self.decoder = WhisperDecoder(
vllm_config=vllm_config, prefix=f"{prefix}.decoder"
)

def forward(
self,
@@ -614,8 +616,7 @@ class WhisperModel(nn.Module):
return None
return self.encoder(input_features)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -645,15 +646,13 @@ class WhisperModel(nn.Module):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class WhisperProcessingInfo(BaseProcessingInfo):

def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)

@@ -670,8 +669,7 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}

def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
@@ -682,7 +680,6 @@ class WhisperProcessingInfo(BaseProcessingInfo):


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):

def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)

@@ -703,16 +700,13 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
audio_overrides = mm_options.get("audio") if mm_options else None

return {
"audio":
self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}


class WhisperMultiModalProcessor(
EncDecMultiModalProcessor[WhisperProcessingInfo]):

class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@@ -779,11 +773,14 @@ class WhisperMultiModalProcessor(
]


@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal):
@MULTIMODAL_REGISTRY.register_processor(
WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
@@ -793,10 +790,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}

hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
".fc1.": ".mlp.fc1.",
".fc2.": ".mlp.fc2."
})
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
)

# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
@@ -811,23 +807,26 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
logger.warning(
"Defaulting to language='en'. If you wish to transcribe "
"audio in a different language, pass the `language` field "
"in the TranscriptionRequest.")
"in the TranscriptionRequest."
)
language = "en"
return super().validate_language(language)

@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str],
) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt")
"Language must be specified when creating the Whisper prompt"
)
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.
@@ -836,10 +835,11 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
"audio": (audio, stt_config.sample_rate),
},
},
"decoder_prompt":
((f"<|prev|>{request_prompt}" if request_prompt else "") +
f"<|startoftranscript|><|{language}|>" +
f"<|{task_type}|><|notimestamps|>")
"decoder_prompt": (
(f"<|prev|>{request_prompt}" if request_prompt else "")
+ f"<|startoftranscript|><|{language}|>"
+ f"<|{task_type}|><|notimestamps|>"
),
}
return cast(PromptType, prompt)
@@ -851,8 +851,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
raise ValueError("Only audio modality is supported")

@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
task_type: str) -> SpeechToTextConfig:
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_get_processor(model_config.model)

return SpeechToTextConfig(
@@ -861,9 +862,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
)

@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig) -> Optional[int]:
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> Optional[int]:
processor = cached_get_processor(model_config.model)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
@@ -871,8 +875,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
# prompts directly at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectogram length.
return math.ceil(audio_duration_s * stt_config.sample_rate /
hop_length)
return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -883,15 +886,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,

self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.unpadded_vocab_size = config.vocab_size
self.proj_out = ParallelLMHead(config.vocab_size,
config.d_model,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj_out"))
self.proj_out = self.proj_out.tie_weights(
self.model.decoder.embed_tokens)
self.proj_out = ParallelLMHead(
config.vocab_size,
config.d_model,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj_out"),
)
self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
)

def forward(
self,
@@ -910,8 +915,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder

def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return [self.model.get_encoder_outputs(audio_input["input_features"])]
@@ -928,16 +932,16 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
# Whisper does not have encoder text tokens.
return self.model.decoder.get_input_embeddings(input_ids)

def _parse_and_validate_audio_input(
self, **kwargs: object) -> WhisperAudioInputs:
def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
input_features = kwargs.pop("input_features", None)

if input_features is not None:
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}")
input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])
raise ValueError(
"Incorrect type of audio features. "
f"Got type: {type(input_features)}"
)
input_features = torch.cat([feat.to(self.dtype) for feat in input_features])

return WhisperAudioInputs(input_features=input_features)

@@ -945,8 +949,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
logits = self.logits_processor(self.proj_out, hidden_states)
return logits

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

# add fake zeros bias for k_proj to state_dict
@@ -955,7 +958,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,


def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]]
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create full zeros bias for k_proj weight in self-attn and x-attn layers.