[Bugfix] Make MM batching more robust (#33817)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -71,9 +71,7 @@ class Step3VLImagePixelInputs(TensorSchema):
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
patch_pixel_values: Annotated[
|
||||
torch.Tensor | None, TensorShape("bnp", 3, "hp", "wp")
|
||||
]
|
||||
patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
|
||||
@@ -91,7 +89,7 @@ class Step3VLImageEmbeddingInputs(TensorSchema):
|
||||
|
||||
Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs
|
||||
|
||||
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
|
||||
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
|
||||
|
||||
MAX_IMAGE_SIZE: int = 3024
|
||||
|
||||
@@ -432,7 +430,7 @@ class Step3VLProcessor:
|
||||
|
||||
if len(parts) - 1 != len(repls):
|
||||
raise ValueError(
|
||||
"The number of placeholders does not match the number of replacements." # noqa: E501
|
||||
"The number of placeholders does not match the number of replacements."
|
||||
)
|
||||
|
||||
result = [parts[0]]
|
||||
@@ -468,7 +466,7 @@ class Step3VLProcessor:
|
||||
image_repl_str_lst = []
|
||||
image_repl_ids_lst = []
|
||||
num_patches = []
|
||||
for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
|
||||
for raw_img, img_patches, patch_newline_mask in splitted_images_data:
|
||||
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
|
||||
|
||||
if len(img_patches) > 0:
|
||||
@@ -486,16 +484,20 @@ class Step3VLProcessor:
|
||||
if patch_newline_mask is not None:
|
||||
patch_newline_mask_lst.extend(patch_newline_mask)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_lst)
|
||||
patch_size = self.patch_size
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_lst),
|
||||
"pixel_values": pixel_values,
|
||||
"num_patches": num_patches,
|
||||
}
|
||||
if patch_pixel_values_lst:
|
||||
image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
|
||||
if patch_newline_mask_lst:
|
||||
image_inputs["patch_newline_mask"] = torch.tensor(
|
||||
"patch_pixel_values": (
|
||||
torch.cat(patch_pixel_values_lst)
|
||||
if patch_pixel_values_lst
|
||||
else pixel_values.new_empty((0, 3, patch_size, patch_size))
|
||||
),
|
||||
"patch_newline_mask": torch.tensor(
|
||||
patch_newline_mask_lst, dtype=torch.bool
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
text = [
|
||||
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
|
||||
@@ -998,13 +1000,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if pixel_values is not None and patch_pixel_values is not None:
|
||||
return Step3VLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=pixel_values.to(self.dtype),
|
||||
patch_pixel_values=patch_pixel_values.to(self.dtype)
|
||||
if patch_pixel_values is not None
|
||||
else None,
|
||||
patch_pixel_values=patch_pixel_values.to(self.dtype),
|
||||
num_patches=num_patches,
|
||||
)
|
||||
|
||||
@@ -1039,7 +1039,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
image_features = self._get_vision_model_output(image_input["pixel_values"])
|
||||
patch_image_features = (
|
||||
self._get_vision_model_output(image_input["patch_pixel_values"])
|
||||
if image_input["patch_pixel_values"] is not None
|
||||
if len(image_input["patch_pixel_values"]) > 0
|
||||
else None
|
||||
)
|
||||
num_patches = image_input["num_patches"]
|
||||
|
||||
Reference in New Issue
Block a user