[Core] Automatically cast multi-modal input dtype (#18756)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -210,9 +210,7 @@ class DeepseekVL2MultiModalProcessor(
|
||||
dict(prompt=prompt, **mm_data),
|
||||
mm_kwargs,
|
||||
)
|
||||
target_dtype = self.info.ctx.model_config.dtype
|
||||
pixel_values = processed_outputs.pop("pixel_values").to(
|
||||
target_dtype)
|
||||
pixel_values = processed_outputs["pixel_values"]
|
||||
# split pixel values into patches corresponding to each image
|
||||
images_spatial_crop = processed_outputs["images_spatial_crop"]
|
||||
patches_per_image = [
|
||||
|
||||
@@ -263,11 +263,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
mm_data,
|
||||
mm_kwargs,
|
||||
)
|
||||
if "pixel_values" in processed_outputs:
|
||||
# Cast pixel values to model dtype already here,
|
||||
# so we need to transfer less data to the GPU
|
||||
processed_outputs["pixel_values"] = processed_outputs[
|
||||
"pixel_values"].to(self.info.ctx.model_config.dtype)
|
||||
|
||||
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
|
||||
if (images := mm_data.get("images")) is not None:
|
||||
|
||||
Reference in New Issue
Block a user