[Bugfix] Fix profiling dummy data for Pixtral (#18677)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-05-25 22:05:30 +08:00
committed by GitHub
parent 3a886bd58c
commit 57fd13a707
5 changed files with 151 additions and 168 deletions

View File

@@ -9,7 +9,9 @@ from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import PixtralVisionConfig, TensorType
@@ -39,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
@@ -224,6 +226,28 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
num_images=num_images)
}
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
dummy_images = dummy_mm_data.get("image", [])
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=dummy_text),
*(ImageChunk(image=image) for image in dummy_images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
):
@@ -275,8 +299,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
(
prompt_ids,
mm_kwargs,
mm_hashes,
_,
) = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,