Voxtral (#20970)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
committed by
GitHub
parent
4ffd963fa0
commit
e7e3e6d263
@@ -10,7 +10,7 @@ on HuggingFace model repository.
|
||||
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
@@ -30,7 +30,9 @@ question_per_audio_count = {
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompt: str
|
||||
prompt: Optional[str] = None
|
||||
prompt_token_ids: Optional[dict[str, list[int]]] = None
|
||||
multi_modal_data: Optional[dict[str, Any]] = None
|
||||
stop_token_ids: Optional[list[int]] = None
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
|
||||
@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# Voxtral
|
||||
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
AudioChunk,
|
||||
RawAudio,
|
||||
TextChunk,
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
model_name = "mistralai/Voxtral-Mini-3B-2507"
|
||||
tokenizer = MistralTokenizer.from_hf_hub(model_name)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
tokenizer_mode="mistral",
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
|
||||
text_chunk = TextChunk(text=question)
|
||||
audios = [
|
||||
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
|
||||
for i in range(audio_count)
|
||||
]
|
||||
audio_chunks = [
|
||||
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
|
||||
]
|
||||
|
||||
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
|
||||
|
||||
req = ChatCompletionRequest(messages=messages, model=model_name)
|
||||
|
||||
tokens = tokenizer.encode_chat_completion(req)
|
||||
prompt_ids, audios = tokens.tokens, tokens.audios
|
||||
|
||||
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
|
||||
|
||||
multi_modal_data = {"audio": audios_and_sr}
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt_token_ids=prompt_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
# Granite Speech
|
||||
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
|
||||
# NOTE - the setting in this example are somehat different than what is
|
||||
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"voxtral": run_voxtral,
|
||||
"granite_speech": run_granite_speech,
|
||||
"minicpmo": run_minicpmo,
|
||||
"phi4_mm": run_phi4mm,
|
||||
@@ -311,16 +368,24 @@ def main(args):
|
||||
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
|
||||
)
|
||||
|
||||
mm_data = {}
|
||||
if audio_count > 0:
|
||||
mm_data = {
|
||||
"audio": [
|
||||
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
|
||||
]
|
||||
}
|
||||
mm_data = req_data.multi_modal_data
|
||||
if not mm_data:
|
||||
mm_data = {}
|
||||
if audio_count > 0:
|
||||
mm_data = {
|
||||
"audio": [
|
||||
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
|
||||
]
|
||||
}
|
||||
|
||||
assert args.num_prompts > 0
|
||||
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
|
||||
inputs = {"multi_modal_data": mm_data}
|
||||
|
||||
if req_data.prompt:
|
||||
inputs["prompt"] = req_data.prompt
|
||||
else:
|
||||
inputs["prompt_token_ids"] = req_data.prompt_token_ids
|
||||
|
||||
if args.num_prompts > 1:
|
||||
# Batch inference
|
||||
inputs = [inputs] * args.num_prompts
|
||||
|
||||
Reference in New Issue
Block a user