[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author:       Cyrus Leung
Date:         2025-01-04 19:40:53 +08:00
Committed by: GitHub
Parent:       300acb8347
Commit:       eed11ebee9

31 changed files with 1104 additions and 973 deletions


@@ -25,11 +25,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        InputProcessingContext,
+from vllm.multimodal.processing import (InputProcessingContext,
                                         MultiModalDataItems, ProcessingCache,
-                                        ProcessorInputs, PromptReplacement,
-                                        full_groupby_modality)
+                                        ProcessorInputs, PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
 from .clip import CLIPVisionModel
@@ -39,7 +37,7 @@ from .pixtral import (PixtralHFVisionModel,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import vision_encoder_info
+from .vision import BaseVisionLanguageMultiModalProcessor
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol):
     vision_feature_layer: Final[Union[int, List[int]]]
 
 
-class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
-
-    def __init__(self,
-                 ctx: InputProcessingContext,
-                 *,
-                 cache: Optional[ProcessingCache] = None,
-                 enable_sanity_checks: bool = True) -> None:
-        super().__init__(ctx,
-                         cache=cache,
-                         enable_sanity_checks=enable_sanity_checks)
-
-        vision_config = self._get_hf_config().vision_config
-        self._vision_encoder_info = vision_encoder_info(vision_config)
+class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
 
     @abstractmethod
     def _get_hf_config(self) -> LlavaLikeConfig:
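
The __init__ deleted above is not gone: per this hunk, it moves into the new
shared base class in vision.py, so the LLaVA-family processors (including the
merged LLaVA-NeXT-Video and LLaVA-OneVision ones) share the vision-encoder
setup instead of duplicating it. A minimal sketch of that base class,
reconstructed from the removed lines; the exact body in vision.py may differ:

    from typing import Optional

    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            InputProcessingContext,
                                            ProcessingCache)


    class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):

        def __init__(self,
                     ctx: InputProcessingContext,
                     *,
                     cache: Optional[ProcessingCache] = None,
                     enable_sanity_checks: bool = True) -> None:
            super().__init__(ctx,
                             cache=cache,
                             enable_sanity_checks=enable_sanity_checks)

            # vision_encoder_info is defined alongside this class in vision.py
            vision_config = self._get_hf_config().vision_config
            self._vision_encoder_info = vision_encoder_info(vision_config)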
@@ -121,6 +107,19 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {"image": self._get_max_image_tokens()}
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
     def _apply_feature_select_strategy(
         self,
         strategy: str,
@@ -142,19 +141,6 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
             self._vision_encoder_info.get_max_image_tokens(),
         )
 
-    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
-        return {"image": self._get_max_image_tokens()}
-
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
-        )
-
     def _get_dummy_image_size(self) -> ImageSize:
         image_size = self._vision_encoder_info.get_image_size()
         return ImageSize(image_size, image_size)
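
For orientation: these two methods were not dropped, they moved earlier in the
class (added in the previous hunk), and get_mm_max_tokens_per_item gained a
seq_len parameter along the way. On _get_mm_fields_config, my reading (not
verbatim docs) is that MultiModalFieldConfig.batched("image") declares that the
leading dimension of the tensor indexes the image items of a request, so the
processor can split HF outputs back into per-item slices. A toy illustration
of that layout:

    import torch

    num_images = 2
    # one slice per image item along dim 0: (num_images, C, H, W)
    pixel_values = torch.zeros(num_images, 3, 336, 336)
    # what "batched" implies: per-item slices are recoverable by indexing dim 0
    per_item = [pixel_values[i] for i in range(num_images)]
    assert len(per_item) == num_images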
@@ -163,8 +149,9 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
     def _get_image_token(self) -> str:
         raise NotImplementedError
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
+        seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)
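
The profiling hook is renamed from _get_dummy_mm_inputs to
_get_dummy_processor_inputs and now also receives the target sequence length.
A standalone toy version of the new call shape, using stand-in types rather
than vLLM's real ProcessorInputs:

    from dataclasses import dataclass, field
    from typing import Mapping


    @dataclass
    class DummyProcessorInputs:  # stand-in, not vLLM's ProcessorInputs
        prompt_text: str
        mm_data: dict = field(default_factory=dict)


    def get_dummy_processor_inputs(
        seq_len: int,  # new in this commit: dummy data can be sized to seq_len
        mm_counts: Mapping[str, int],
    ) -> DummyProcessorInputs:
        num_images = mm_counts.get("image", 0)
        return DummyProcessorInputs(prompt_text="<image>" * num_images,
                                    mm_data={"num_images": num_images})


    print(get_dummy_processor_inputs(4096, {"image": 2}).prompt_text)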
@@ -709,7 +696,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             "</Image>)",  # 3 tokens
         ])
 
-        mantis_repls = self._bind_prompt_replacements([
+        mantis_mm_repls = self._bind_and_group_repls([
             PromptReplacement(
                 modality="image",
                 target=[image_token_id] * num_image_tokens,
@@ -719,7 +706,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
         prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
             result["prompt_token_ids"],
-            mantis_repls,
+            mantis_mm_repls,
             mm_item_counts,
         )
@@ -728,15 +715,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             hf_processor_mm_kwargs,
             mm_kwargs,
         )
-        orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
+        orig_repls = self._bind_and_group_repls(unbound_orig_repls)
 
-        all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
-                                                   mm_item_counts)
-        assert len(all_placeholders) == mm_item_counts.get("image", 0)
+        mm_placeholders = self._find_mm_placeholders(
+            orig_repls,
+            prompt_ids,
+            mm_item_counts,
+        )
+
+        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
-        mm_placeholders = {
-            modality: [item.to_range() for item in items]
-            for modality, items in full_groupby_modality(all_placeholders)
+        mm_placeholder_ranges = {
+            modality: [item.to_range() for item in placeholders]
+            for modality, placeholders in mm_placeholders.items()
         }
 
         return MultiModalInputsV2(
@@ -744,7 +735,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             prompt=prompt_text,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
-            mm_placeholders=mm_placeholders,
+            mm_placeholders=mm_placeholder_ranges,
         )
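
Net effect in Mantis's apply(): replacement binding now also groups by modality
(_bind_and_group_repls), placeholder discovery returns results keyed by
modality (_find_mm_placeholders), and the old inline assert becomes
_validate_mm_placeholders. A self-contained toy showing just the grouping and
range-building pattern; the placeholder class below is hypothetical, and only
the dict-comprehension shape mirrors the new code:

    from dataclasses import dataclass


    @dataclass
    class ToyPlaceholder:  # stand-in for vLLM's internal placeholder info
        offset: int
        length: int

        def to_range(self) -> dict:
            return {"offset": self.offset, "length": self.length}


    # _find_mm_placeholders now returns placeholders pre-grouped by modality,
    # so no flat list + full_groupby_modality pass is needed.
    mm_placeholders = {
        "image": [ToyPlaceholder(0, 576), ToyPlaceholder(600, 576)],
    }

    mm_placeholder_ranges = {
        modality: [item.to_range() for item in placeholders]
        for modality, placeholders in mm_placeholders.items()
    }

    print(mm_placeholder_ranges)
    # {'image': [{'offset': 0, 'length': 576}, {'offset': 600, 'length': 576}]}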