Migrate Pixtral inputs to TensorSchema (#23472)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck
2025-08-23 21:55:53 -07:00
committed by GitHub
parent c55c028998
commit 053278a5dc

View File

@@ -5,7 +5,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, fields
 from functools import cached_property
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -48,6 +48,7 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import (MistralTokenizer,
                                                cached_tokenizer_from_config)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
@@ -68,15 +69,20 @@ except ImportError:
 PATCH_MERGE = "patch_merge"


-class PixtralImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    images: Union[torch.Tensor, list[torch.Tensor]]
+class PixtralImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each image
+        - w: Width of each image

     The result of stacking `ImageEncoding.tokens` from each prompt.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                      TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]


 class PixtralProcessorAdapter:
@@ -381,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         if images is None:
             return None

-        if not isinstance(images, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of images. "
-                             f"Got type: {type(images)}")
-
         return PixtralImagePixelInputs(
             type="pixel_values",
             images=flatten_bn(images),