[VLM] Separate out profiling-related logic (#11746)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-06 16:02:21 +08:00
committed by GitHub
parent 2a622d704a
commit 996357e480
17 changed files with 1036 additions and 739 deletions

View File

@@ -31,8 +31,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
@@ -48,54 +49,33 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
class ChameleonProcessingMixin(ProcessingMixin):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)
def _get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length
class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor()
return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_end_token,
]),
)
]
def _get_dummy_processor_inputs(
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig)
config = self._get_hf_config()
width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
@@ -112,6 +92,40 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
mm_data=mm_data,
)
class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor(**hf_processor_mm_kwargs)
return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_end_token,
]),
)
]
def apply(
self,
prompt_text: str,