[Bugfix] Fix InternVL2 inference with various num_patches (#8375)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Isotr0py
2024-09-13 01:10:35 +08:00
committed by GitHub
parent 520ca380ae
commit e56bf27741
2 changed files with 39 additions and 3 deletions

View File

@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because the images may have different num_patches
data = [
image_to_pixel_values(img,
image_size,
@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
max_num,
use_thumbnail=use_thumbnail) for img in data
]
data = torch.stack(data)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True).flatten(0, 1)),
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
raise AssertionError("This line should be unreachable.")