[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-08 18:59:58 +08:00
committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

View File

@@ -30,10 +30,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
@@ -49,33 +49,34 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class ChameleonProcessingMixin(ProcessingMixin):
class ChameleonProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length
class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
processor = self.get_hf_processor()
return processor.image_seq_length
class ChameleonDummyInputsBuilder(
BaseDummyInputsBuilder[ChameleonProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self._get_hf_config()
config = self.info.get_hf_config()
width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
@@ -93,11 +94,8 @@ class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
)
class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)
class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]):
def _get_mm_fields_config(
self,
@@ -112,7 +110,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor(**hf_processor_mm_kwargs)
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
return [
PromptReplacement(
@@ -120,7 +118,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_token * self.info.get_num_image_tokens(),
processor.image_end_token,
]),
)
@@ -916,7 +914,10 @@ class ChameleonModel(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
ChameleonMultiModalProcessor,
info=ChameleonProcessingInfo,
dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):