[Bugfix] Make MM batching more robust (#33817)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-06 04:40:58 +08:00
committed by GitHub
parent 4145e50d85
commit 116880a5a0
13 changed files with 625 additions and 428 deletions

View File

@@ -71,9 +71,7 @@ class Step3VLImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
patch_pixel_values: Annotated[
torch.Tensor | None, TensorShape("bnp", 3, "hp", "wp")
]
patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
@@ -91,7 +89,7 @@ class Step3VLImageEmbeddingInputs(TensorSchema):
Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
MAX_IMAGE_SIZE: int = 3024
@@ -432,7 +430,7 @@ class Step3VLProcessor:
if len(parts) - 1 != len(repls):
raise ValueError(
"The number of placeholders does not match the number of replacements." # noqa: E501
"The number of placeholders does not match the number of replacements."
)
result = [parts[0]]
@@ -468,7 +466,7 @@ class Step3VLProcessor:
image_repl_str_lst = []
image_repl_ids_lst = []
num_patches = []
for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
for raw_img, img_patches, patch_newline_mask in splitted_images_data:
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
if len(img_patches) > 0:
@@ -486,16 +484,20 @@ class Step3VLProcessor:
if patch_newline_mask is not None:
patch_newline_mask_lst.extend(patch_newline_mask)
pixel_values = torch.cat(pixel_values_lst)
patch_size = self.patch_size
image_inputs = {
"pixel_values": torch.cat(pixel_values_lst),
"pixel_values": pixel_values,
"num_patches": num_patches,
}
if patch_pixel_values_lst:
image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
if patch_newline_mask_lst:
image_inputs["patch_newline_mask"] = torch.tensor(
"patch_pixel_values": (
torch.cat(patch_pixel_values_lst)
if patch_pixel_values_lst
else pixel_values.new_empty((0, 3, patch_size, patch_size))
),
"patch_newline_mask": torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
)
),
}
text = [
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
@@ -998,13 +1000,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if pixel_values is not None and patch_pixel_values is not None:
return Step3VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values.to(self.dtype),
patch_pixel_values=patch_pixel_values.to(self.dtype)
if patch_pixel_values is not None
else None,
patch_pixel_values=patch_pixel_values.to(self.dtype),
num_patches=num_patches,
)
@@ -1039,7 +1039,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
image_features = self._get_vision_model_output(image_input["pixel_values"])
patch_image_features = (
self._get_vision_model_output(image_input["patch_pixel_values"])
if image_input["patch_pixel_values"] is not None
if len(image_input["patch_pixel_values"]) > 0
else None
)
num_patches = image_input["num_patches"]