[VLM] Separate out profiling-related logic (#11746)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -4,8 +4,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
 
 import torch
 import torch.nn as nn
-from transformers import (BatchFeature, Blip2Config, Blip2Processor,
-                          Blip2QFormerConfig, apply_chunking_to_forward)
+from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
+                          apply_chunking_to_forward)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
@@ -18,8 +18,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
@@ -396,20 +397,52 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-class Blip2MultiModalProcessor(BaseMultiModalProcessor):
+class Blip2ProcessingMixin(ProcessingMixin):
+
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config(Blip2Config)
+
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self._get_hf_config()
+        return hf_config.num_query_tokens
+
+
+class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
 
-    def _get_num_image_tokens(self) -> int:
-        hf_config = self.ctx.get_hf_config(Blip2Config)
-        return hf_config.num_query_tokens
-
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}
 
-    def _get_hf_processor(self) -> Blip2Processor:
-        return self.ctx.get_hf_processor(Blip2Processor)
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self._get_hf_config()
+        vision_config = hf_config.vision_config
+
+        max_image_size = vision_config.image_size
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return Blip2ProfilingInfo(self.ctx)
 
     def _get_mm_fields_config(
         self,
@@ -427,13 +460,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        max_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self._get_num_image_tokens()
 
         return [
             PromptReplacement(
                 modality="image",
                 target="</s>",
-                replacement="<image>" * max_image_tokens + "</s>",
+                replacement="<image>" * num_image_tokens + "</s>",
             )
         ]
 
@@ -457,29 +490,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
 
         return result
 
-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        hf_config = self.ctx.get_hf_config(Blip2Config)
-        vision_config = hf_config.vision_config
-
-        max_image_size = vision_config.image_size
-        num_images = mm_counts.get("image", 0)
-
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=max_image_size,
-                                   height=max_image_size,
-                                   num_images=num_images)
-        }
-
-        return ProcessorInputs(
-            prompt_text="",
-            mm_data=mm_data,
-        )
-
 
 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
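In outline, the commit splits each model's multimodal code into three pieces: a ProcessingMixin subclass holding the config lookups shared by everything, a BaseProfilingInfo subclass holding the dummy-data generation used only for memory profiling, and the BaseMultiModalProcessor subclass that wires the two together via _get_profiling_info(). The sketch below applies the same split to a hypothetical model, using only the base classes and hook names visible in this diff; MyProcessingMixin, MyProfilingInfo, MyMultiModalProcessor, and the num_image_tokens config attribute are illustrative placeholders, not part of the commit.

    from typing import Mapping, Optional

    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            ProcessingMixin)
    from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs


    class MyProcessingMixin(ProcessingMixin):
        """Config lookups shared by the processor and its profiler."""

        def _get_num_image_tokens(self) -> int:
            # Hypothetical: read the per-image token count from the HF config.
            return self.ctx.get_hf_config().num_image_tokens


    class MyProfilingInfo(MyProcessingMixin, BaseProfilingInfo):
        """Profiling-only logic, kept out of the runtime processor."""

        def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
            return {"image": 1}

        def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
            return {"image": self._get_num_image_tokens()}

        def get_dummy_processor_inputs(
            self,
            seq_len: int,
            mm_counts: Mapping[str, int],
        ) -> ProcessorInputs:
            num_images = mm_counts.get("image", 0)
            # _get_dummy_images is inherited from BaseProfilingInfo,
            # as in the BLIP-2 code above; 336x336 is a placeholder size.
            mm_data = {
                "image": self._get_dummy_images(width=336,
                                                height=336,
                                                num_images=num_images)
            }
            return ProcessorInputs(prompt_text="", mm_data=mm_data)


    class MyMultiModalProcessor(MyProcessingMixin, BaseMultiModalProcessor):

        def _get_profiling_info(self) -> BaseProfilingInfo:
            # The only profiling-related code left in the processor.
            return MyProfilingInfo(self.ctx)

Note that get_dummy_processor_inputs now lives on the profiling class rather than the processor (before this change it was _get_dummy_processor_inputs on Blip2MultiModalProcessor), so dummy-input generation no longer ships inside the request-time processing path.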