[Refactor] Decouple TimingContext from InputProcessingContext (#35083)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
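This change decouples timing bookkeeping from `InputProcessingContext`: `MultiModalProcessor.apply()` now receives an explicit `ProcessorInputs` bundle and a `TimingContext` instead of loose prompt/kwargs arguments. As a rough idea of the interface implied by the `with timing_ctx.record(...)` call sites below, a minimal sketch (an assumption, not the actual vLLM implementation) could look like:

import time
from contextlib import contextmanager

class TimingContext:
    """Accumulates wall-clock durations per named processing stage."""

    def __init__(self) -> None:
        self.durations: dict[str, float] = {}

    @contextmanager
    def record(self, name: str):
        # Time the enclosed block and accumulate it under `name`.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.durations[name] = (
                self.durations.get(name, 0.0) + time.perf_counter() - start
            )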
@@ -37,12 +37,13 @@ from vllm.multimodal.inputs import (
 from vllm.multimodal.parse import (
     ImageProcessorItems,
     MultiModalDataItems,
-    MultiModalUUIDItems,
 )
 from vllm.multimodal.processing import (
     BaseDummyInputsBuilder,
     BaseMultiModalProcessor,
     BaseProcessingInfo,
+    ProcessorInputs,
+    TimingContext,
 )
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -177,11 +178,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
 
     def apply(
         self,
-        prompt: str | list[int],
-        mm_items: MultiModalDataItems,
-        mm_uuid_items: MultiModalUUIDItems | None = None,
-        hf_processor_mm_kwargs: Mapping[str, object] | None = None,
-        tokenization_kwargs: Mapping[str, object] | None = None,
+        inputs: ProcessorInputs,
+        timing_ctx: TimingContext,
     ) -> MultiModalInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -189,29 +187,30 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         Apply HF Processor on prompt text and multi-modal data together,
         outputting token IDs and processed tensors.
         """
-        if hf_processor_mm_kwargs is None:
-            hf_processor_mm_kwargs = {}
-        if tokenization_kwargs is None:
-            tokenization_kwargs = {}
+        prompt = inputs.prompt
+        mm_items = inputs.mm_data_items
+        hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
+        tokenization_kwargs = inputs.tokenization_kwargs
 
-        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        if not isinstance(prompt, str):
-            # the prompt is the tokenized ids which is not supported
-            # by the hf_processor, which is why we would need to decode the ids
-            # into string
-            prompt = hf_processor.decode(prompt)
+        with timing_ctx.record("apply_hf_processor"):
+            hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+            if not isinstance(prompt, str):
+                # the prompt is the tokenized ids which is not supported
+                # by the hf_processor, which is why we would need to decode the ids
+                # into string
+                prompt = hf_processor.decode(prompt)
 
-        # Bypass cached processor and always apply to the full set of mm inputs
-        # NOTE: we can't just set caching=False because base class method
-        # transforms outputs to `MultiModalKwargs` which is not going to
-        # work for Transformers. We have a lot of logic tied to
-        # `mm_tokens_per_modality` below
-        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
-            prompt_text=prompt,
-            mm_items=mm_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-            tokenization_kwargs=tokenization_kwargs,
-        )
+            # Bypass cached processor and always apply to the full set of mm inputs
+            # NOTE: we can't just set caching=False because base class method
+            # transforms outputs to `MultiModalKwargs` which is not going to
+            # work for Transformers. We have a lot of logic tied to
+            # `mm_tokens_per_modality` below
+            prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
+                prompt_text=prompt,
+                mm_items=mm_items,
+                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+                tokenization_kwargs=tokenization_kwargs,
+            )
 
         # For gemma3 we check `token_type_ids` as the key
         token_type_key = (
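The accessors used above (`inputs.prompt`, `inputs.mm_data_items`, `inputs.hf_processor_mm_kwargs`, `inputs.tokenization_kwargs`), together with `inputs.get_mm_hashes(...)` in the final hunk, suggest `ProcessorInputs` is a plain carrier for everything `apply()` previously took as separate parameters. A minimal sketch, with the field names taken from these call sites and everything else assumed:

from collections.abc import Mapping
from dataclasses import dataclass, field

@dataclass
class ProcessorInputs:
    prompt: str | list[int]
    mm_data_items: "MultiModalDataItems"
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)

    def get_mm_hashes(self, model_id: str):
        # Placeholder: per the removed _hash_mm_items() call, the real
        # method presumably uses caller-provided UUID overrides when
        # present and falls back to data-dependent hashing otherwise.
        raise NotImplementedError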
@@ -225,15 +224,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         # it for each input `mm_data`.
         mm_positions = torch.where(mm_token_type_ids == 1)[1]
         images = mm_items.get_items("image", ImageProcessorItems)
-        multimodal_config = self.info.ctx.model_config.multimodal_config
-        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
         image_sizes = []
         for item_idx in range(len(images)):
             image_size = images.get_image_size(item_idx)
             image_sizes.append((image_size.height, image_size.width))
 
         mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
-            image_sizes=image_sizes, **mm_processor_kwargs
+            image_sizes=image_sizes,
+            **self.info.ctx.get_merged_mm_kwargs({}),
         )
 
         mm_placeholders = {}
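Replacing the direct read of `multimodal_config.mm_processor_kwargs` with `self.info.ctx.get_merged_mm_kwargs({})` points to a helper that merges config-level processor kwargs with per-call overrides. A guess at its behavior, assuming a simple dict merge where overrides win:

def get_merged_mm_kwargs(self, overrides):
    # Config-level kwargs first, then per-call overrides on top.
    merged = dict(self.model_config.multimodal_config.mm_processor_kwargs or {})
    merged.update(overrides)
    return merged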
@@ -261,11 +259,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         )
 
-        # Use overrides if provided; fallback to data-dependent hashing.
-        mm_hashes = self._hash_mm_items(
-            mm_items,
-            mm_uuid_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-        )
+        with timing_ctx.record("get_mm_hashes"):
+            mm_hashes = inputs.get_mm_hashes(self.info.model_id)
 
         return mm_inputs(
             prompt_token_ids=prompt_ids,
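Under the sketched interfaces above, a caller would drive the refactored method roughly like this (illustrative only; construction of `inputs` is assumed):

timing_ctx = TimingContext()
mm_inputs = processor.apply(inputs, timing_ctx)

# Per-stage timings are now owned by the caller rather than by
# InputProcessingContext, e.g.:
# {'apply_hf_processor': 0.031, 'get_mm_hashes': 0.002}
print(timing_ctx.durations)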