[Voxtral] Fix speech transcription api (#31388)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Committed by GitHub
parent 2972a05473
commit 18d4e481d0
```diff
@@ -17,10 +17,11 @@ class SpeechToTextConfig:
     16kHz audio input. The input audio will be automatically resampled to this
     rate before processing."""
 
-    max_audio_clip_s: int = 30
+    max_audio_clip_s: int | None = 30
     """Maximum duration in seconds for a single audio clip without chunking.
     Audio longer than this will be split into smaller chunks if
-    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected.
+    `None` means audio duration can be unlimited and won't be chunked."""
 
     overlap_chunk_second: int = 1
     """Overlap duration in seconds between consecutive audio chunks when
```
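The `int | None` widening above changes what callers can assume about chunk counts. A minimal sketch of the rule (hypothetical helper, not vLLM source; it ignores the overlap and energy-based split points the real chunker uses):

```python
import math

# Hypothetical helper mirroring the new config semantics (not vLLM source):
# an int bounds each clip's duration; None disables chunking entirely.
def expected_num_chunks(duration_s: float, max_audio_clip_s: int | None) -> int:
    if max_audio_clip_s is None:
        return 1  # unlimited clip length, never chunked
    return max(1, math.ceil(duration_s / max_audio_clip_s))

assert expected_num_chunks(95.0, 30) == 4    # 95 s at up to 30 s per chunk
assert expected_num_chunks(95.0, None) == 1  # unlimited, single pass
```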
```diff
@@ -477,7 +477,15 @@ class OpenAISpeechToText(OpenAIServing):
         }
         segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
         text = ""
+        chunk_size_in_s = self.asr_config.max_audio_clip_s
+        if chunk_size_in_s is None:
+            assert len(list_result_generator) == 1, (
+                "`max_audio_clip_s` is set to None, audio cannot be chunked"
+            )
         for idx, result_generator in enumerate(list_result_generator):
+            start_time = (
+                float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
+            )
             async for op in result_generator:
                 if request.response_format == "verbose_json":
                     segments: list[SpeechToTextSegment] = (
@@ -485,7 +493,7 @@ class OpenAISpeechToText(OpenAIServing):
                             tokens=tuple(op.outputs[0].token_ids),
                             segment_class=segment_class,
                             request=request,
-                            start_time=idx * self.asr_config.max_audio_clip_s,
+                            start_time=start_time,
                         )
                     )
 
```
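Hoisting `start_time` out of the segment call is what makes the offsets None-safe; the failure mode of the old inline expression, in plain Python:

```python
# Old line: start_time=idx * self.asr_config.max_audio_clip_s
# Once the config value may be None, that multiplication raises:
chunk_size_in_s = None
try:
    _ = 1 * chunk_size_in_s
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for *: 'int' and 'NoneType'

# New guarded expression: a single un-chunked clip simply starts at 0.0.
start_time = float(1 * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
assert start_time == 0.0
```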
```diff
@@ -653,6 +661,10 @@ class OpenAISpeechToText(OpenAIServing):
     def _split_audio(
         self, audio_data: np.ndarray, sample_rate: int
     ) -> list[np.ndarray]:
+        assert self.asr_config.max_audio_clip_s is not None, (
+            f"{self.asr_config.max_audio_clip_s=} cannot be None to"
+            " split audio into chunks."
+        )
         chunk_size = sample_rate * self.asr_config.max_audio_clip_s
         overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
```
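The assertion added to `_split_audio` guards the arithmetic below it. A standalone sketch of that chunk/overlap arithmetic (hypothetical function, not vLLM source; the real splitter also honors `min_energy_split_window_size`):

```python
import numpy as np

# Illustrative stand-in for the guarded chunking arithmetic (not vLLM source).
def split_audio(audio: np.ndarray, sample_rate: int,
                max_clip_s: int, overlap_s: int) -> list[np.ndarray]:
    chunk_size = sample_rate * max_clip_s
    overlap_size = sample_rate * overlap_s
    chunks, start = [], 0
    while start < len(audio):
        chunks.append(audio[start:start + chunk_size])
        start += chunk_size - overlap_size  # adjacent chunks share overlap_s seconds
    return chunks

audio = np.zeros(16_000 * 65)  # 65 s of silence at 16 kHz
print([len(c) / 16_000 for c in split_audio(audio, 16_000, 30, 1)])
# [30.0, 30.0, 7.0]
```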
```diff
@@ -17,7 +17,11 @@ from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChu
 from mistral_common.protocol.instruct.messages import UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
-from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
+from mistral_common.tokens.tokenizers.audio import (
+    Audio,
+    AudioEncoder,
+    TranscriptionFormat,
+)
 from transformers import BatchFeature, TensorType, WhisperConfig
 from transformers.tokenization_utils_base import TextInput
 
@@ -157,13 +161,17 @@ class VoxtralProcessorAdapter:
 
         # pad if necessary
         # TODO(Patrick) - remove once mistral-common is bumped
-        sig = inspect.signature(self._audio_processor.pad)
-        if "is_online_streaming" in sig.parameters:
-            audio = self._audio_processor.pad(
-                audio, self.sampling_rate, is_online_streaming=False
-            )
-        else:
-            audio = self._audio_processor.pad(audio, self.sampling_rate)
+        if (
+            self._audio_processor.audio_config.transcription_format
+            != TranscriptionFormat.STREAMING
+        ):
+            sig = inspect.signature(self._audio_processor.pad)
+            if "is_online_streaming" in sig.parameters:
+                audio = self._audio_processor.pad(
+                    audio, self.sampling_rate, is_online_streaming=False
+                )
+            else:
+                audio = self._audio_processor.pad(audio, self.sampling_rate)
 
         audio_tokens = [self.begin_audio_token_id] + [
             self.audio_token_id
```
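The `inspect.signature` probe kept inside the new branch is a dependency-version guard: pass a newer keyword only when the installed library accepts it. Isolated, the idiom looks like this (the `pad` stub stands in for an older mistral-common and is not vLLM source):

```python
import inspect

def pad(audio, sampling_rate):  # stand-in for an older mistral-common pad()
    return audio

kwargs = {}
if "is_online_streaming" in inspect.signature(pad).parameters:
    kwargs["is_online_streaming"] = False  # skipped for this older signature
padded = pad([0.0] * 4, 16_000, **kwargs)
```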
```diff
@@ -3,10 +3,19 @@
 
 import math
 from collections.abc import Mapping
+from typing import Literal, cast
 
+import numpy as np
 import torch
+from mistral_common.protocol.instruct.chunk import RawAudio
+from mistral_common.protocol.transcription.request import (
+    StreamingMode,
+    TranscriptionRequest,
+)
+from mistral_common.tokens.tokenizers.audio import Audio
 
-from vllm.config.vllm import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.voxtral import (
@@ -27,6 +36,7 @@ from vllm.multimodal.processing import (
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.tokenizers import cached_tokenizer_from_config
 
 from .utils import (
     _flatten_embeddings,
```
```diff
@@ -205,13 +215,17 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
             "For streaming you must provide an audio input at every step."
         )
 
-        multiple_of = self.audio_config.raw_audio_length_per_tok
-        assert all(
-            (this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
-        ), (
-            f"Every input audio waveform has to be a multiple of {multiple_of}, but"
-            f" one is {this_audio} with {(this_audio / multiple_of)=}."
-        )
+        def _truncate_left(
+            sample: torch.Tensor, mult_of: int, pos: int
+        ) -> torch.Tensor:
+            assert pos in [0, 1], pos
+            if (ctx := sample.shape[pos] % mult_of) != 0:
+                sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
+            assert sample.shape[pos] > 0, (
+                f"Sample is empty after truncation with ctx {ctx}"
+            )
+
+            return sample
 
         mel_features = [
             self.whisper_encoder.compute_whisper_melspec(audio).to(
```
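What the new `_truncate_left` helper does, replayed on a toy tensor (same arithmetic, outside the class):

```python
import torch

# Drop leading elements along dim `pos` until its size is a multiple of
# `mult_of`; here pos=1 and mult_of=2, as in the mel-feature call below.
sample = torch.arange(10).reshape(2, 5)  # dim 1 has size 5
ctx = sample.shape[1] % 2                # remainder 1 must be trimmed
trimmed = sample[:, ctx:]                # drop the left-most column
assert trimmed.shape == (2, 4)           # now divisible by mult_of=2
```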
```diff
@@ -219,11 +233,16 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
             )
             for audio in audio_inputs
         ]
 
+        # we truncate the left-most mel feature
+        # if the sequence length is odd
+        mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features]
+
         seq_lens = [mel.shape[1] for mel in mel_features]
         # [total_num_20ms_frames, hidden_size]
         audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
             mel_features
-        )[0]
+        )
         conv_stride = self.whisper_encoder.whisper_encoder.total_stride
         audio_embeddings_per_sample = audio_embeddings.split(
             [s // conv_stride for s in seq_lens], dim=0
```
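The conv output is consumed as one concatenated `[total_frames, hidden]` tensor that `Tensor.split` carves back into per-sample chunks sized by `seq_len // conv_stride`; a shape-only sketch with made-up sizes:

```python
import torch

seq_lens, conv_stride, hidden = [8, 4], 2, 16
flat = torch.randn(sum(s // conv_stride for s in seq_lens), hidden)  # [6, 16]
per_sample = flat.split([s // conv_stride for s in seq_lens], dim=0)
assert [t.shape[0] for t in per_sample] == [4, 2]  # one chunk per input sample
```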
```diff
@@ -231,13 +250,55 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
 
         # audio_embeddings per sample need to be divisible by 4
         pool_size = self.config.audio_config.block_pool_size
-        assert all(
-            (this_shape := sample.shape[0]) % pool_size == 0
+
+        audio_embeddings_per_sample = [
+            _truncate_left(sample, pool_size, 0)
             for sample in audio_embeddings_per_sample
-        ), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
+        ]
 
         audio_embeddings_per_sample = [
             e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
             for e in audio_embeddings_per_sample
         ]
         return audio_embeddings_per_sample
 
+    @classmethod
+    def get_speech_to_text_config(
+        cls, model_config: ModelConfig, task_type: str
+    ) -> SpeechToTextConfig:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio_config = tokenizer.instruct.audio_encoder.audio_config
+        sample_rate = audio_config.sampling_rate
+        return SpeechToTextConfig(
+            max_audio_clip_s=None,  # only limited by memory
+            sample_rate=sample_rate,
+            min_energy_split_window_size=None,
+        )
+
+    @classmethod
+    # for speech-to-text transcription
+    def get_generation_prompt(
+        cls,
+        audio: np.ndarray,
+        model_config: ModelConfig,
+        stt_config: SpeechToTextConfig,
+        language: str | None,
+        task_type: Literal["transcribe", "translate"],
+        request_prompt: str,
+        to_language: str | None,
+    ) -> PromptType:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
+
+        req = TranscriptionRequest(
+            model=model_config.model,
+            audio=RawAudio.from_audio(audio),
+            language=language,
+            streaming=StreamingMode.OFFLINE,
+        )
+
+        tokenized = tokenizer.instruct.encode_transcription(req)
+        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}}
+        prompts_dict["prompt_token_ids"] = tokenized.tokens
+        return cast(PromptType, prompts_dict)
```
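A quick end-to-end exercise of the fixed endpoint, assuming a locally running vLLM OpenAI-compatible server (the serve command, port, audio file name, and checkpoint below are placeholders, not taken from this commit):

```python
from openai import OpenAI

# Assumes something like:
#   vllm serve mistralai/Voxtral-Mini-3B-2507 --tokenizer-mode mistral
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
with open("sample.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="mistralai/Voxtral-Mini-3B-2507",
        file=f,
        language="en",
    )
# With max_audio_clip_s=None, Voxtral processes the file as one clip.
print(transcription.text)
```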
```diff
@@ -469,8 +469,10 @@ class WhisperEncoder(nn.Module):
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
-        is_causal = getattr(config, "is_causal", False)
-        Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
+        self.is_causal = getattr(config, "is_causal", False)
+        Conv1d = (
+            WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
+        )
 
         self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
         self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
@@ -485,7 +487,7 @@ class WhisperEncoder(nn.Module):
         )
         self.layer_norm = nn.LayerNorm(config.d_model)
 
-        if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
+        if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
             raise ValueError(
                 "Only NOPE position embeddings are supported "
                 f"for causal models, but got {self.pos_embed_type}"
```
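The `Conv1d` switch picks between `WhisperCausalConv1d` and a symmetrically padded `nn.Conv1d`. A sketch of the distinction (the left-only padding is assumed behavior for the causal variant, which this diff does not show):

```python
import torch
import torch.nn as nn

# A causal conv pads kernel_size - 1 frames on the left only, so output
# frame t never sees input frames later than t; the symmetric baseline does.
x = torch.randn(1, 80, 12)                                # [batch, mel_bins, time]
causal = nn.Conv1d(80, 80, kernel_size=3)                 # no built-in padding
y_causal = causal(nn.functional.pad(x, (2, 0)))           # left-pad kernel_size - 1
symmetric = nn.Conv1d(80, 80, kernel_size=3, padding=1)   # non-causal baseline
assert y_causal.shape[-1] == symmetric(x).shape[-1] == 12  # same output length
```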
```diff
@@ -536,8 +538,11 @@ class WhisperEncoder(nn.Module):
             hidden_states.append(embeds)
         input_is_batched = embeds.ndim > 2
         # Input to MHA must be B x T x D
-        if input_is_batched:
+        if input_is_batched or self.is_causal:
             # Models using WhisperEncoder may handle batching internally.
+            # If WhisperEncoder is causal, sequences
+            # are not padded to have identical seq length (T)
+            # => concat over feature dim
             hidden_states = torch.cat(hidden_states)
         else:
             hidden_states = torch.stack(hidden_states, dim=0)
```
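Why the causal path must join `hidden_states` with `torch.cat`: un-padded causal sequences have differing lengths T, which `torch.stack` rejects, while `cat` along the first (time) dimension accepts them:

```python
import torch

a, b = torch.randn(5, 4), torch.randn(3, 4)  # two [T, D] sequences, T=5 and T=3
flat = torch.cat([a, b])                     # [8, 4]: joined along the time dim
try:
    torch.stack([a, b], dim=0)               # would need equal T across samples
except RuntimeError:
    print("stack requires identical shapes; cat does not")
```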