[Transformers v5] fix missing pixtral/voxtral multimodal dispatch (#38410)

Signed-off-by: allgather <all2allops@gmail.com>
This commit is contained in:
allgather
2026-03-29 02:59:06 -07:00
committed by GitHub
parent 43cc5138e5
commit 8c0b6267d7
4 changed files with 40 additions and 22 deletions

View File

@@ -61,7 +61,10 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers import cached_tokenizer_from_config
 from vllm.tokenizers.mistral import MistralTokenizer
-from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
+from vllm.transformers_utils.processors.pixtral import (
+    MistralCommonImageProcessor,
+    MistralCommonPixtralProcessor,
+)
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -128,18 +131,20 @@ class PixtralProcessingInfo(BaseProcessingInfo):
         return tokenizer

+    def get_image_processor(self) -> MistralCommonImageProcessor:
+        return MistralCommonImageProcessor(self.get_tokenizer().instruct.mm_encoder)
+
     def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor:
-        return self.ctx.init_processor(
-            MistralCommonPixtralProcessor,
+        return MistralCommonPixtralProcessor(
             tokenizer=self.get_tokenizer(),
-            **kwargs,
+            image_processor=self.get_image_processor(),
         )

     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None}

     def get_image_size_with_most_features(self) -> ImageSize:
-        image_processor = self.get_hf_processor().image_processor
+        image_processor = self.get_image_processor()
         max_image_size = image_processor.mm_encoder.mm_config.max_image_size
         return ImageSize(width=max_image_size, height=max_image_size)

View File

@@ -55,7 +55,10 @@ from vllm.multimodal.processing.processor import (
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers import cached_tokenizer_from_config
 from vllm.tokenizers.mistral import MistralTokenizer
-from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
+from vllm.transformers_utils.processors.voxtral import (
+    MistralCommonFeatureExtractor,
+    MistralCommonVoxtralProcessor,
+)
 from vllm.utils.collection_utils import is_list_of

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
@@ -84,15 +87,19 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
         return tokenizer

+    def get_feature_extractor(self) -> MistralCommonFeatureExtractor:
+        return MistralCommonFeatureExtractor(
+            self.get_tokenizer().instruct.audio_encoder
+        )
+
     def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
-        return self.ctx.init_processor(
-            MistralCommonVoxtralProcessor,
+        return MistralCommonVoxtralProcessor(
             tokenizer=self.get_tokenizer(),
-            **kwargs,
+            feature_extractor=self.get_feature_extractor(),
         )

     def get_data_parser(self):
-        feature_extractor = self.get_hf_processor().feature_extractor
+        feature_extractor = self.get_feature_extractor()
         return MultiModalDataParser(
             target_sr=feature_extractor.sampling_rate,
@@ -114,7 +121,7 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
         return self.ctx.model_config.max_model_len

     def get_max_audio_array_len(self) -> int:
-        feature_extractor = self.get_hf_processor().feature_extractor
+        feature_extractor = self.get_feature_extractor()
         return self.get_max_audio_tokens() * int(
             feature_extractor.sampling_rate // feature_extractor.frame_rate
@@ -153,7 +160,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         mm_data: MultiModalDataDict | None = None,
     ) -> ProcessorInputs:
         tokenizer = self.info.get_tokenizer()
-        feature_extractor = self.info.get_hf_processor().feature_extractor
+        feature_extractor = self.info.get_feature_extractor()

         dummy_text = self.get_dummy_text(mm_counts)
         dummy_mm_data = (
@@ -480,8 +487,10 @@ class VoxtralForConditionalGeneration(
         This is used for estimating the amount of processing for this audio.
         """
         tokenizer = cached_tokenizer_from_config(model_config)
-        adapter = MistralCommonVoxtralProcessor(tokenizer)
-        return adapter.feature_extractor.get_num_audio_tokens(
+        feature_extractor = MistralCommonFeatureExtractor(
+            tokenizer.instruct.audio_encoder
+        )
+        return feature_extractor.get_num_audio_tokens(
             int(audio_duration_s * stt_config.sample_rate)
         )