[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from functools import partial
 from typing import Callable
 
 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
-from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
+from transformers import Qwen2VLForConditionalGeneration
 
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from ....utils import large_gpu_test
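The only functional change in the import block is the switch to the conventional `F` alias for `torch.nn.functional`, which the rewritten pooling code below uses. As a reminder of what that call computes, here is a minimal, self-contained sketch (the tensor values are illustrative only):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([3.0, 4.0])

# F.normalize divides by the L2 norm (5.0 here), producing a unit vector.
unit = F.normalize(x, p=2, dim=-1)
assert torch.allclose(unit, x / x.norm(p=2))  # tensor([0.6000, 0.8000])
```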
@@ -75,10 +75,6 @@ def apply_chat_template_and_add_eos(
     return prompt
 
 
-def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
-    return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)
-
-
 def _run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
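The deleted `postprocess_inputs` helper existed only to be pre-bound with `functools.partial` inside `_run_test` (see the next hunk); the commit inlines the `prepare_inputs_for_generation` call at the one site that needs it. For readers unfamiliar with the binding pattern being removed, here is a minimal illustration with hypothetical names:

```python
from functools import partial

def describe(model_name: str, prompt: str, *, use_cache: bool) -> str:
    return f"{model_name} <- {prompt!r} (use_cache={use_cache})"

# partial fixes the first positional argument and a keyword argument;
# callers then supply only the remaining parameters.
bound = partial(describe, "qwen2-vl", use_cache=False)
print(bound("hello"))  # qwen2-vl <- 'hello' (use_cache=False)
```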
@@ -118,14 +114,8 @@ def _run_test(
     with hf_runner(model,
                    dtype=dtype,
                    auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
-        hf_model.postprocess_inputs = partial(
-            postprocess_inputs,
-            hf_model,
-            cache_position=torch.arange(
-                0,
-                1,  # 1 for batch size
-                requires_grad=False),
-            use_cache=False)
+        prompts = []
+
         for text, image, embed_text in zip(input_texts, input_images,
                                            embed_texts):
             # dse requires non-standard input processing
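One detail worth noting in the hunk above: the old hook built its cache position as `torch.arange(0, 1, requires_grad=False)`, while the replacement in the next hunk uses the terser `torch.arange(1)`. The two are equivalent, as this quick check shows:

```python
import torch

# torch.arange(1) == torch.arange(0, 1) == tensor([0]); integer tensors
# cannot require grad anyway, so requires_grad=False was redundant.
assert torch.equal(torch.arange(1), torch.arange(0, 1, requires_grad=False))
```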
@@ -133,20 +123,34 @@ def _run_test(
             messages = get_messages(image, text, embed_text)
             prompt = apply_chat_template_and_add_eos(
                 messages, hf_model.processor.apply_chat_template)
-            inputs = hf_model.get_inputs(
-                prompts=[[prompt]],
-                images=[[image]],
-            )
-            with torch.no_grad():
+            prompts.append(prompt)
+
+        all_inputs = hf_model.get_inputs(
+            prompts=prompts,
+            images=input_images,
+        )
+
+        with torch.no_grad():
+            all_outputs = []
+            for inputs in all_inputs:
+                inputs = hf_model.model.prepare_inputs_for_generation(
+                    **inputs,
+                    cache_position=torch.arange(1),  # 1 for batch size
+                    use_cache=False,
+                )
                 outputs = hf_model.model(
-                    **hf_model.wrap_device(inputs[0],
-                                           device=hf_model.model.device.type),
+                    **hf_model.wrap_device(inputs),
                     return_dict=True,
                     output_hidden_states=True,
                 )
-                pooled_output = torch.nn.functional.normalize(
-                    outputs.hidden_states[-1][0, -1], p=2, dim=-1)
-            hf_outputs.append(pooled_output.tolist())
+                pooled_output = F.normalize(outputs.hidden_states[-1][0, -1],
+                                            p=2,
+                                            dim=-1)
+
+                all_outputs.append(pooled_output.tolist())
+
+        hf_outputs = all_outputs
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
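The rewritten loop batches input preparation through `hf_model.get_inputs` but pools each embedding the same way as before: take the final layer's hidden state at the last token position and L2-normalize it, the usual last-token pooling for DSE-style retrieval models. A self-contained sketch of that pooling step, using a stand-in for `outputs.hidden_states` (the helper name and fake shapes are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def pool_last_token(hidden_states: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Last-token pooling: take the final layer's representation of the
    last position in a single-sequence batch, then L2-normalize it so
    cosine similarity downstream reduces to a dot product."""
    last_layer = hidden_states[-1]  # (batch, seq_len, hidden)
    last_token = last_layer[0, -1]  # (hidden,) for a batch of one
    return F.normalize(last_token, p=2, dim=-1)

# Illustrative stand-in for `outputs.hidden_states` with batch=1, seq=4.
fake_hidden_states = (torch.randn(1, 4, 8), torch.randn(1, 4, 8))
embedding = pool_last_token(fake_hidden_states)
assert torch.allclose(embedding.norm(p=2), torch.tensor(1.0))
```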