[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-08 18:59:58 +08:00
committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

View File

@@ -34,13 +34,12 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo,
BoundPromptReplacement,
PlaceholderInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -302,9 +301,9 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
class Phi3VProcessingMixin(ProcessingMixin):
class Phi3VProcessingInfo(BaseProcessingInfo):
def _get_hf_processor(
def get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
@@ -314,39 +313,42 @@ class Phi3VProcessingMixin(ProcessingMixin):
return self.ctx.get_hf_processor()
def _get_num_image_tokens(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin],
) -> int:
processor = self._get_hf_processor()
if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_image_tokens = self._get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50)
class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@@ -354,7 +356,8 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
@@ -363,7 +366,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
num_images=num_images)
}
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
@@ -372,10 +375,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
)
class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Phi3VProfilingInfo(self.ctx)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def _call_hf_processor(
self,
@@ -416,10 +416,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
@@ -431,9 +431,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
@@ -451,9 +452,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
@@ -466,7 +467,7 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = {
modality: [
_PlaceholderInfo(
PlaceholderInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
@@ -499,7 +500,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={