[VLM] Various cleanup and fixes (#14806)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-14 20:58:19 +08:00
committed by GitHub
parent 40253bab44
commit ab93f1360f
14 changed files with 283 additions and 273 deletions

View File

@@ -32,7 +32,7 @@ from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
@@ -315,7 +315,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(flatten_bn(pixel_values)),
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
@@ -434,7 +435,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]
pixel_values = inputs["pixel_values"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape