Migrate FuyuImagePatchInputs to TensorSchema (#21662)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
|
||||
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
|
||||
|
||||
|
||||
@@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
|
||||
"w": 336
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_tensor_schema_with_list_of_symbolic_dim():
|
||||
flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn)
|
||||
patches_per_image = [64, 64, 64] # len = bn = 3
|
||||
|
||||
FuyuImagePatchInputs(
|
||||
flat_data=flat_data,
|
||||
patches_per_image=patches_per_image,
|
||||
)
|
||||
|
||||
|
||||
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
|
||||
flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn)
|
||||
patches_per_image = [64, 64, 64] # len = 3 ≠ bn
|
||||
|
||||
with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
|
||||
FuyuImagePatchInputs(
|
||||
flat_data=flat_data,
|
||||
patches_per_image=patches_per_image,
|
||||
)
|
||||
Reference in New Issue
Block a user