[Models] Cohere ASR (#35809)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
@@ -19,8 +19,10 @@ import soundfile
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from evaluate import load
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
from ....models.registry import HF_EXAMPLE_MODELS
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@@ -64,8 +66,12 @@ async def bound_transcribe(sem, client, tokenizer, audio, reference):
|
||||
async def process_dataset(model, client, data, concurrent_request):
|
||||
sem = asyncio.Semaphore(concurrent_request)
|
||||
|
||||
# Load tokenizer once outside the loop
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
tokenizer = get_tokenizer(
|
||||
model,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
)
|
||||
|
||||
# Warmup call as the first `librosa.load` server-side is quite slow.
|
||||
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
|
||||
@@ -144,20 +150,35 @@ def run_evaluation(
|
||||
|
||||
|
||||
# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
|
||||
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
||||
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
||||
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
||||
@pytest.mark.parametrize(
|
||||
"model_config",
|
||||
[
|
||||
("openai/whisper-large-v3", 12.744980),
|
||||
# TODO (ekagra): add HF ckpt after asr release
|
||||
# ("/host/engines/vllm/audio/2b-release", 11.73),
|
||||
],
|
||||
)
|
||||
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
|
||||
)
|
||||
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
||||
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
||||
@pytest.mark.parametrize("expected_wer", [12.744980])
|
||||
def test_wer_correctness(
|
||||
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
|
||||
model_config, dataset_repo, n_examples=-1, max_concurrent_request=None
|
||||
):
|
||||
model_name, expected_wer = model_config
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_name)
|
||||
# TODO refactor to use `ASRDataset`
|
||||
server_args = [
|
||||
"--enforce-eager",
|
||||
f"--tokenizer_mode={model_info.tokenizer_mode}",
|
||||
]
|
||||
if model_info.trust_remote_code:
|
||||
server_args.append("--trust-remote-code")
|
||||
with RemoteOpenAIServer(
|
||||
model_name, ["--enforce-eager"], max_wait_seconds=480
|
||||
model_name,
|
||||
server_args,
|
||||
) as remote_server:
|
||||
dataset = load_hf_dataset(dataset_repo)
|
||||
|
||||
@@ -167,7 +188,14 @@ def test_wer_correctness(
|
||||
|
||||
client = remote_server.get_async_client()
|
||||
wer = run_evaluation(
|
||||
model_name, client, dataset, max_concurrent_request, n_examples
|
||||
model_name,
|
||||
client,
|
||||
dataset,
|
||||
max_concurrent_request,
|
||||
n_examples,
|
||||
)
|
||||
|
||||
print(f"Expected WER: {expected_wer}, Actual WER: {wer}")
|
||||
|
||||
if expected_wer:
|
||||
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user