[Bugfix] Fix profiling dummy data for Pixtral (#18677)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -9,7 +9,9 @@ from typing import Literal, Optional, TypedDict, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mistral_common.protocol.instruct.messages import ImageChunk
|
||||
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
||||
UserMessage)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
||||
from PIL import Image
|
||||
from transformers import PixtralVisionConfig, TensorType
|
||||
@@ -39,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, MultiModalHashes,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
@@ -224,6 +226,28 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
|
||||
dummy_images = dummy_mm_data.get("image", [])
|
||||
|
||||
request = ChatCompletionRequest(messages=[
|
||||
UserMessage(content=[
|
||||
TextChunk(text=dummy_text),
|
||||
*(ImageChunk(image=image) for image in dummy_images),
|
||||
]),
|
||||
])
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
dummy_tokens = res.tokens
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
|
||||
|
||||
|
||||
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
):
|
||||
@@ -275,8 +299,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
*,
|
||||
return_mm_hashes: bool,
|
||||
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
|
||||
prompt_ids, mm_kwargs, mm_hashes, _ = super(
|
||||
)._cached_apply_hf_processor(
|
||||
(
|
||||
prompt_ids,
|
||||
mm_kwargs,
|
||||
mm_hashes,
|
||||
_,
|
||||
) = super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user