[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -5,11 +5,10 @@ from typing import Optional
 import numpy as np
 import pytest
 import pytest_asyncio
-from transformers import AutoModel, AutoTokenizer, BatchEncoding
+from transformers import AutoModel, AutoTokenizer
 
 from vllm.multimodal.audio import resample_audio
 from vllm.sequence import SampleLogprobs
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 from ....conftest import HfRunner, VllmRunner
 from ....utils import RemoteOpenAIServer
@@ -107,8 +106,6 @@ def run_test(
     **kwargs,
 ):
     """Inference result should be the same between hf and vllm."""
-    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
-
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
@@ -124,15 +121,7 @@ def run_test(
         for vllm_prompt, _, audio in prompts_and_audios
     ]
 
-    def process(hf_inputs: BatchEncoding, **kwargs):
-        hf_inputs["audio_values"] = hf_inputs["audio_values"] \
-            .to(torch_dtype)  # type: ignore
-        return hf_inputs
-
-    with hf_runner(model,
-                   dtype=dtype,
-                   postprocess_inputs=process,
-                   auto_cls=AutoModel) as hf_model:
+    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
         hf_outputs_per_audio = [
             hf_model.generate_greedy_logprobs_limit(
                 [hf_prompt],
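In short, this hunk deletes a per-test postprocessing hook that cast the audio tensor to the target dtype by hand. A self-contained sketch of the deleted pattern, with torch_dtype passed explicitly rather than captured from the enclosing scope as it was in the original:

import torch
from transformers import BatchEncoding

def process(hf_inputs: BatchEncoding, torch_dtype: torch.dtype, **kwargs):
    # Cast the audio tensor to the requested dtype; everything else
    # in the BatchEncoding passes through unchanged.
    hf_inputs["audio_values"] = hf_inputs["audio_values"].to(torch_dtype)
    return hf_inputs

With the runner now applying the dtype conversion itself, individual model tests no longer need to import BatchEncoding or carry model-specific casting hooks, which is what shrinks all three hunks.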