[Core] Registry for processing model inputs (#5214)

Co-authored-by: ywang96 <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-06-28 20:09:56 +08:00
committed by GitHub
parent 0d0e3a42ac
commit 5cbe8d155c
26 changed files with 784 additions and 398 deletions

View File

@@ -25,14 +25,14 @@ def test_clip_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)
for asset in image_assets:
@@ -40,10 +40,9 @@ def test_clip_image_processor(image_assets, dtype):
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
assert hf_result.keys() == vllm_result.keys()
@@ -74,14 +73,14 @@ def test_llava_next_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)
for asset in image_assets:
@@ -89,10 +88,9 @@ def test_llava_next_image_processor(image_assets, dtype):
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
assert hf_result.keys() == vllm_result.keys()
@@ -119,26 +117,23 @@ def test_image_pixel_types(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
)
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
))
for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.process_input(
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
tensor_result = MULTIMODAL_REGISTRY.process_input(
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pixel_values),
model_config=model_config,
vlm_config=vlm_config,
)
assert image_result.keys() == tensor_result.keys()