[VLM] Separate out profiling-related logic (#11746)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-06 16:02:21 +08:00
committed by GitHub
parent 2a622d704a
commit 996357e480
17 changed files with 1036 additions and 739 deletions

View File

@@ -4,8 +4,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
import torch
import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2Processor,
Blip2QFormerConfig, apply_chunking_to_forward)
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
@@ -18,8 +18,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel
@@ -396,20 +397,52 @@ class Blip2QFormerModel(nn.Module):
return sequence_output
class Blip2MultiModalProcessor(BaseMultiModalProcessor):
class Blip2ProcessingMixin(ProcessingMixin):
    """Shared accessors for BLIP-2 settings read from the HF config.

    Both helpers rely on ``self.ctx``; this class does not set it itself
    (presumably the base class or the instantiating code does; confirm).
    """

    def _get_hf_config(self):
        """Return the config obtained via ``ctx.get_hf_config(Blip2Config)``."""
        return self.ctx.get_hf_config(Blip2Config)

    def _get_num_image_tokens(self) -> int:
        """Return the ``num_query_tokens`` value of the HF config."""
        return self._get_hf_config().num_query_tokens
class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
    """Profiling-side description of BLIP-2 image inputs.

    Reports the supported modality limits, the token count per image, and
    builds the dummy processor inputs used for profiling.
    """

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: a single image."""
        return {"image": 1}

    def _get_num_image_tokens(self) -> int:
        """Return ``num_query_tokens`` read directly from the HF config."""
        return self.ctx.get_hf_config(Blip2Config).num_query_tokens

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the token count for one image.

        ``seq_len`` is accepted but not consulted: the value comes only
        from ``_get_num_image_tokens``.
        """
        tokens_per_image = self._get_num_image_tokens()
        return {"image": tokens_per_image}

    def _get_hf_processor(self) -> Blip2Processor:
        """Return the processor from ``ctx.get_hf_processor(Blip2Processor)``."""
        return self.ctx.get_hf_processor(Blip2Processor)

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build dummy inputs made of square images and an empty prompt.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; a missing ``"image"``
                entry is treated as zero images.

        Returns:
            ``ProcessorInputs`` with an empty prompt text and the dummy
            images stored under the ``"image"`` key.
        """
        # The vision config's ``image_size`` is used for both dimensions.
        side = self._get_hf_config().vision_config.image_size
        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
        )
        return ProcessorInputs(
            prompt_text="",
            mm_data={"image": dummy_images},
        )
class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
    """Return a ``Blip2ProfilingInfo`` built from this object's ``ctx``."""
    profiling_info = Blip2ProfilingInfo(self.ctx)
    return profiling_info
def _get_mm_fields_config(
self,
@@ -427,13 +460,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = self._get_num_image_tokens()
num_image_tokens = self._get_num_image_tokens()
return [
PromptReplacement(
modality="image",
target="</s>",
replacement="<image>" * max_image_tokens + "</s>",
replacement="<image>" * num_image_tokens + "</s>",
)
]
@@ -457,29 +490,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
return result
def _get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    """Build dummy inputs made of square images and an empty prompt.

    Args:
        seq_len: Not used by this implementation.
        mm_counts: Number of items per modality; a missing ``"image"``
            entry is treated as zero images.

    Returns:
        ``ProcessorInputs`` with an empty prompt text and the dummy images
        stored under the ``"image"`` key.
    """
    # The vision config's ``image_size`` is used for both dimensions.
    side = self.ctx.get_hf_config(Blip2Config).vision_config.image_size
    dummy_images = self._get_dummy_images(
        width=side,
        height=side,
        num_images=mm_counts.get("image", 0),
    )
    return ProcessorInputs(
        prompt_text="",
        mm_data={"image": dummy_images},
    )
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):