[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -25,11 +25,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        InputProcessingContext,
+from vllm.multimodal.processing import (InputProcessingContext,
                                         MultiModalDataItems, ProcessingCache,
-                                        ProcessorInputs, PromptReplacement,
-                                        full_groupby_modality)
+                                        ProcessorInputs, PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
 from .clip import CLIPVisionModel
@@ -39,7 +37,7 @@ from .pixtral import (PixtralHFVisionModel,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import vision_encoder_info
+from .vision import BaseVisionLanguageMultiModalProcessor
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol):
     vision_feature_layer: Final[Union[int, List[int]]]
 
 
-class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
-
-    def __init__(self,
-                 ctx: InputProcessingContext,
-                 *,
-                 cache: Optional[ProcessingCache] = None,
-                 enable_sanity_checks: bool = True) -> None:
-        super().__init__(ctx,
-                         cache=cache,
-                         enable_sanity_checks=enable_sanity_checks)
-
-        vision_config = self._get_hf_config().vision_config
-        self._vision_encoder_info = vision_encoder_info(vision_config)
+class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
 
     @abstractmethod
     def _get_hf_config(self) -> LlavaLikeConfig:
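Note: the removed __init__ above is not deleted outright; it is consolidated into the shared BaseVisionLanguageMultiModalProcessor imported from .vision, so LLaVA, LLaVA-NeXT, LLaVA-NeXT-Video and LLaVA-OneVision no longer each repeat the vision-encoder setup. Below is a rough sketch of what that shared base presumably provides, reconstructed from the removed lines; only the names visible in this diff are taken from the source, and the placement of vision_encoder_info next to the base class is an assumption, not a copy of the real vision module.

# Sketch only: reconstructed from the __init__ removed above; an
# approximation of the shared base, not the actual vLLM implementation.
from typing import Optional

from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        InputProcessingContext,
                                        ProcessingCache)

from .vision import vision_encoder_info  # same helper the old llava.py import pulled in


class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
    """Common setup for processors whose HF config exposes a vision_config."""

    def __init__(self,
                 ctx: InputProcessingContext,
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__(ctx,
                         cache=cache,
                         enable_sanity_checks=enable_sanity_checks)

        # The two lines every LLaVA-family processor used to repeat.
        vision_config = self._get_hf_config().vision_config
        self._vision_encoder_info = vision_encoder_info(vision_config)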
@@ -121,6 +107,19 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {"image": self._get_max_image_tokens()}
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
     def _apply_feature_select_strategy(
         self,
         strategy: str,
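The _get_mm_fields_config hook moved up here declares how HF processor outputs map back to individual items: MultiModalFieldConfig.batched("image") marks tensors whose leading dimension indexes the image items of the request, so they can be split back per item. A toy, self-contained illustration of that convention follows; the 336x336 CLIP-style shape is invented for the example and not taken from this diff.

# Toy illustration only: with a "batched" field, a tensor whose first
# dimension indexes the images can be split back into per-image entries.
import torch

pixel_values = torch.zeros(2, 3, 336, 336)  # two images in one request

per_image = list(pixel_values.unbind(dim=0))
assert len(per_image) == 2
assert per_image[0].shape == (3, 336, 336)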
@@ -142,19 +141,6 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
             self._vision_encoder_info.get_max_image_tokens(),
         )
 
-    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
-        return {"image": self._get_max_image_tokens()}
-
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
-        )
-
     def _get_dummy_image_size(self) -> ImageSize:
         image_size = self._vision_encoder_info.get_image_size()
         return ImageSize(image_size, image_size)
@@ -163,8 +149,9 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
     def _get_image_token(self) -> str:
         raise NotImplementedError
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
+        seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)
@@ -709,7 +696,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             "</Image>)",  # 3 tokens
         ])
 
-        mantis_repls = self._bind_prompt_replacements([
+        mantis_mm_repls = self._bind_and_group_repls([
             PromptReplacement(
                 modality="image",
                 target=[image_token_id] * num_image_tokens,
@@ -719,7 +706,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
         prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
             result["prompt_token_ids"],
-            mantis_repls,
+            mantis_mm_repls,
             mm_item_counts,
         )
 
@@ -728,15 +715,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             hf_processor_mm_kwargs,
             mm_kwargs,
         )
-        orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
+        orig_repls = self._bind_and_group_repls(unbound_orig_repls)
 
-        all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
-                                                   mm_item_counts)
-        assert len(all_placeholders) == mm_item_counts.get("image", 0)
+        mm_placeholders = self._find_mm_placeholders(
+            orig_repls,
+            prompt_ids,
+            mm_item_counts,
+        )
 
-        mm_placeholders = {
-            modality: [item.to_range() for item in items]
-            for modality, items in full_groupby_modality(all_placeholders)
+        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
+
+        mm_placeholder_ranges = {
+            modality: [item.to_range() for item in placeholders]
+            for modality, placeholders in mm_placeholders.items()
         }
 
         return MultiModalInputsV2(
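The rewritten block follows the same three-step flow as the base processor: find the per-modality placeholders, validate their counts against mm_item_counts, then convert each placeholder to a range for MultiModalInputsV2. A toy sketch of the shape change performed by the final dict comprehension follows; PlaceholderInfo and its to_range() output are stand-ins for illustration, not the vLLM types.

# Toy sketch: per-modality placeholder objects become per-modality range
# dicts. PlaceholderInfo is a stand-in for whatever _find_mm_placeholders
# actually returns.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class PlaceholderInfo:
    start_idx: int
    length: int

    def to_range(self) -> Dict[str, int]:
        return {"offset": self.start_idx, "length": self.length}


mm_placeholders: Dict[str, List[PlaceholderInfo]] = {
    "image": [PlaceholderInfo(start_idx=3, length=576),
              PlaceholderInfo(start_idx=600, length=576)],
}

mm_placeholder_ranges = {
    modality: [item.to_range() for item in placeholders]
    for modality, placeholders in mm_placeholders.items()
}
# {'image': [{'offset': 3, 'length': 576}, {'offset': 600, 'length': 576}]}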
@@ -744,7 +735,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             prompt=prompt_text,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
-            mm_placeholders=mm_placeholders,
+            mm_placeholders=mm_placeholder_ranges,
         )