[VLM] Various cleanup and fixes (#14806)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -32,7 +32,7 @@ from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
|
||||
|
||||
class LlavaNextImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
@@ -315,7 +315,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return LlavaNextImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(flatten_bn(pixel_values)),
|
||||
pixel_values=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values)),
|
||||
image_sizes=self._validate_image_sizes(
|
||||
flatten_bn(image_sizes, concat=True)),
|
||||
)
|
||||
@@ -434,7 +435,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
pixel_values = inputs["pixel_values"]
|
||||
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
b, num_patches, c, h, w = pixel_values.shape
|
||||
|
||||
Reference in New Issue
Block a user