[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-08 18:59:58 +08:00
committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

View File

@@ -33,11 +33,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
@@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict):
"""
class FuyuProcessingMixin(ProcessingMixin):
class FuyuProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(FuyuConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(FuyuProcessor)
def _get_image_processor(self) -> FuyuImageProcessor:
return self._get_hf_processor().image_processor
def get_image_processor(self) -> FuyuImageProcessor:
return self.get_hf_processor().image_processor
def _get_image_feature_grid_size(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the supported multimodal item limits, keyed by modality name.

    Returns:
        A mapping with the single key ``"image"`` set to ``1``.
    """
    # NOTE(review): presumably this caps a prompt at one image; confirm
    # against how the caller interprets the value.
    return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
    """Return the largest per-image token count, keyed by modality name.

    The count is derived from the feature grid of the image size reported
    by ``get_image_size_with_most_features``.

    Args:
        seq_len: Not used by this implementation.

    Returns:
        A mapping with the single key ``"image"`` whose value is
        ``(max_ncols + 1) * max_nrows``.
    """
    target_width, target_height = self.get_image_size_with_most_features()
    max_ncols, max_nrows = self.get_image_feature_grid_size(
        image_width=target_width,
        image_height=target_height,
    )
    # Each row contributes its grid columns plus one additional token.
    # NOTE(review): presumably the extra token is a per-row separator;
    # confirm against the prompt-replacement logic.
    max_image_tokens = (max_ncols + 1) * max_nrows

    return {"image": max_image_tokens}
def get_image_feature_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
image_processor = self._get_image_processor()
image_processor = self.get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]
@@ -97,34 +111,21 @@ class FuyuProcessingMixin(ProcessingMixin):
nrows = math.ceil(image_height / 30)
return ncols, nrows
class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
image_processor = self._get_image_processor()
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
@@ -140,10 +141,7 @@ class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
)
class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return FuyuProfilingInfo(self.ctx)
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
def _call_hf_processor(
self,
@@ -156,7 +154,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@@ -196,10 +194,10 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
@@ -207,7 +205,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self._get_image_feature_grid_size(
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
@@ -244,7 +242,9 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
info=FuyuProcessingInfo,
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):