[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,8 +4,7 @@ from typing import Optional, overload
 
 import pytest
 import torch
-from transformers import (AutoConfig, AutoModelForImageTextToText,
-                          AutoTokenizer, BatchEncoding)
+from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer
 
 from vllm import LLM, SamplingParams
 from vllm.attention.backends.flash_attn import FlashAttentionMetadata
@@ -227,13 +226,9 @@ def _run_test(
         for prompts, images in inputs
     ]
 
-    def process(hf_inputs: BatchEncoding, **kwargs):
-        return hf_inputs
-
     with hf_runner(model,
                    dtype=dtype,
                    model_kwargs={"device_map": "auto"},
-                   postprocess_inputs=process,
                    auto_cls=AutoModelForImageTextToText) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
Reference in New Issue
Block a user