[V1][VLM] Proper memory profiling for image language models (#11210)

Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: ywang96 <ywang@example.com>
This commit is contained in:
Roger Wang
2024-12-16 22:10:57 -08:00
committed by GitHub
parent 66d4b16724
commit 59c9b6ebeb
6 changed files with 98 additions and 13 deletions

View File

@@ -245,6 +245,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
# If the last split index is the last index in image_tokens, we
# ignore it to avoid empty split tensor
if split_indices[-1] == len(image_tokens):
split_indices = split_indices[:-1]
image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds