[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date:   2025-01-08 18:59:58 +08:00
Committed by: GitHub
Parent: f12141170a
Commit: 2a0596bc48

23 changed files with 833 additions and 760 deletions
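
In the new layout, the per-model ProcessingMixin/ProfilingInfo pair is replaced by a BaseProcessingInfo subclass plus a BaseDummyInputsBuilder subclass, the multimodal processor is parametrized by the info class, and the pieces are wired together through register_processor(..., info=..., dummy_inputs=...). The sketch below restates that pattern in one place, assuming only the classes and signatures visible in the Blip-2 hunks further down; the "My*" names, the nn.Module base, the literal token count, and the elided (...) bodies are illustrative placeholders, not part of the commit.

# A minimal sketch of the reorganized interface, based only on what the
# Blip-2 diff below shows. "My*" names are hypothetical.
from collections.abc import Mapping
from typing import Optional

from torch import nn

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyProcessingInfo(BaseProcessingInfo):
    # Replaces the old <Model>ProcessingMixin / <Model>ProfilingInfo pair;
    # private helpers such as _get_hf_config() become public methods here.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": 32}  # placeholder; BLIP-2 uses hf_config.num_query_tokens


class MyDummyInputsBuilder(BaseDummyInputsBuilder[MyProcessingInfo]):
    # Takes over get_dummy_processor_inputs() from the old ProfilingInfo;
    # config lookups now go through self.info instead of a mixin method.

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        ...  # body elided; see the Blip2DummyInputsBuilder hunk below


class MyMultiModalProcessor(BaseMultiModalProcessor[MyProcessingInfo]):
    # No longer inherits a ProcessingMixin or exposes _get_profiling_info();
    # it is parametrized by the info class instead.
    ...


@MULTIMODAL_REGISTRY.register_processor(MyMultiModalProcessor,
                                        info=MyProcessingInfo,
                                        dummy_inputs=MyDummyInputsBuilder)
class MyModelForConditionalGeneration(nn.Module):
    ...

Compared with the old layout, the processor no longer needs a _get_profiling_info() hook: the registry receives the info and dummy-inputs factories directly from the decorator arguments.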

vllm/model_executor/models/blip2.py

@@ -17,10 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
@@ -397,30 +397,30 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-class Blip2ProcessingMixin(ProcessingMixin):
+class Blip2ProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Blip2Config)
 
-    def _get_num_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()
-        return hf_config.num_query_tokens
-
-
-class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_num_image_tokens()}
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        return hf_config.num_query_tokens
+
+
+class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         vision_config = hf_config.vision_config
 
         max_image_size = vision_config.image_size
@@ -439,10 +439,7 @@ class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
         )
 
 
-class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Blip2ProfilingInfo(self.ctx)
+class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
 
     def _get_mm_fields_config(
         self,
@@ -460,7 +457,7 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        num_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self.info.get_num_image_tokens()
 
         return [
             PromptReplacement(
@@ -491,7 +488,9 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
         return result
 
 
-@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
+                                        info=Blip2ProcessingInfo,
+                                        dummy_inputs=Blip2DummyInputsBuilder)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):