[VLM] Separate out profiling-related logic (#11746)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -31,8 +31,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
MultiModalDataItems, ProcessingMixin,
|
||||
PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
@@ -48,54 +49,33 @@ class ChameleonImagePixelInputs(TypedDict):
|
||||
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
||||
|
||||
|
||||
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
|
||||
class ChameleonProcessingMixin(ProcessingMixin):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
def _get_hf_config(self):
|
||||
return self.ctx.get_hf_config(ChameleonConfig)
|
||||
|
||||
def _get_hf_processor(self):
|
||||
return self.ctx.get_hf_processor(ChameleonProcessor)
|
||||
|
||||
def _get_num_image_tokens(self) -> int:
|
||||
processor = self._get_hf_processor()
|
||||
return processor.image_seq_length
|
||||
|
||||
|
||||
class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
return {"image": self._get_num_image_tokens()}
|
||||
|
||||
def _get_hf_processor(self) -> ChameleonProcessor:
|
||||
return self.ctx.get_hf_processor(ChameleonProcessor)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
processor = self._get_hf_processor()
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement="".join([
|
||||
processor.image_start_token,
|
||||
processor.image_token * self._get_num_image_tokens(),
|
||||
processor.image_end_token,
|
||||
]),
|
||||
)
|
||||
]
|
||||
|
||||
def _get_dummy_processor_inputs(
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
config = self.ctx.get_hf_config(ChameleonConfig)
|
||||
config = self._get_hf_config()
|
||||
|
||||
width = height = config.vq_config.resolution
|
||||
num_images = mm_counts.get("image", 0)
|
||||
@@ -112,6 +92,40 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
|
||||
BaseMultiModalProcessor):
|
||||
|
||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
||||
return ChameleonProfilingInfo(self.ctx)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
processor = self._get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement="".join([
|
||||
processor.image_start_token,
|
||||
processor.image_token * self._get_num_image_tokens(),
|
||||
processor.image_end_token,
|
||||
]),
|
||||
)
|
||||
]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
|
||||
Reference in New Issue
Block a user