Migrate Pixtral inputs to TensorSchema (#23472)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import math
|
|||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Literal, Optional, TypedDict, Union
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -48,6 +48,7 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||||
cached_tokenizer_from_config)
|
cached_tokenizer_from_config)
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||||
@@ -68,15 +69,20 @@ except ImportError:
|
|||||||
PATCH_MERGE = "patch_merge"
|
PATCH_MERGE = "patch_merge"
|
||||||
|
|
||||||
|
|
||||||
class PixtralImagePixelInputs(TypedDict):
|
class PixtralImagePixelInputs(TensorSchema):
|
||||||
type: Literal["pixel_values"]
|
|
||||||
|
|
||||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
|
Dimensions:
|
||||||
|
- bn: Batch size * number of images
|
||||||
|
- c: Number of channels (3)
|
||||||
|
- h: Height of each image
|
||||||
|
- w: Width of each image
|
||||||
|
|
||||||
The result of stacking `ImageEncoding.tokens` from each prompt.
|
The result of stacking `ImageEncoding.tokens` from each prompt.
|
||||||
"""
|
"""
|
||||||
|
type: Literal["pixel_values"] = "pixel_values"
|
||||||
|
|
||||||
|
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]
|
||||||
|
|
||||||
|
|
||||||
class PixtralProcessorAdapter:
|
class PixtralProcessorAdapter:
|
||||||
@@ -381,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
if images is None:
|
if images is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not isinstance(images, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of images. "
|
|
||||||
f"Got type: {type(images)}")
|
|
||||||
|
|
||||||
return PixtralImagePixelInputs(
|
return PixtralImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
images=flatten_bn(images),
|
images=flatten_bn(images),
|
||||||
|
|||||||
Reference in New Issue
Block a user