[Misc] Clean up input processing (#17582)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
|
||||
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
|
||||
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
|
||||
@@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
|
||||
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_intern_vit_test(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
@@ -21,11 +23,12 @@ def run_intern_vit_test(
|
||||
dtype: str,
|
||||
):
|
||||
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
|
||||
img_processor = CLIPImageProcessor.from_pretrained(model)
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
pixel_values = [
|
||||
img_processor(images, return_tensors='pt').pixel_values.to(dtype)
|
||||
img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
|
||||
for images in images
|
||||
]
|
||||
|
||||
@@ -34,7 +37,7 @@ def run_intern_vit_test(
|
||||
config.norm_type = "rms_norm"
|
||||
|
||||
hf_model = AutoModel.from_pretrained(model,
|
||||
torch_dtype=dtype,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True).to("cuda")
|
||||
hf_outputs_per_image = [
|
||||
hf_model(pixel_value.to("cuda")).last_hidden_state
|
||||
@@ -48,7 +51,7 @@ def run_intern_vit_test(
|
||||
del hf_model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
vllm_model = vllm_model.to("cuda", dtype)
|
||||
vllm_model = vllm_model.to("cuda", torch_dtype)
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model(pixel_values=pixel_value.to("cuda"))
|
||||
for pixel_value in pixel_values
|
||||
@@ -66,9 +69,8 @@ def run_intern_vit_test(
|
||||
"OpenGVLab/InternViT-300M-448px",
|
||||
"OpenGVLab/InternViT-6B-448px-V1-5",
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", [torch.half])
|
||||
@torch.inference_mode()
|
||||
def test_models(image_assets, model_id, dtype: str) -> None:
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
|
||||
run_intern_vit_test(
|
||||
image_assets,
|
||||
model_id,
|
||||
|
||||
Reference in New Issue
Block a user