[VLM] Support caching in merged multi-modal processor (#11396)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
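
This change moves Phi-3-Vision onto the merged multi-modal processor interface so that HF processor outputs can be cached per item: the processor kwargs are renamed to `hf_processor_mm_kwargs`, a new `_get_mm_fields_config()` declares how batched HF outputs map onto individual images, and the prompt-replacement bookkeeping is tightened to match the HF processor token-for-token. A minimal sketch of the caching idea (hypothetical helper, not vLLM's actual cache implementation):

```python
import hashlib
import pickle
from typing import Any, Callable, Mapping

_processor_cache: dict[str, Any] = {}

def cached_process_item(process: Callable[..., Any], item: Any,
                        hf_processor_mm_kwargs: Mapping[str, object]) -> Any:
    # Key on the item and on the kwargs that influence the HF processor's
    # output, so e.g. a different num_crops does not hit a stale entry.
    key = hashlib.sha256(
        pickle.dumps((item, sorted(hf_processor_mm_kwargs.items())))
    ).hexdigest()
    if key not in _processor_cache:
        _processor_cache[key] = process(item, **hf_processor_mm_kwargs)
    return _processor_cache[key]
```

Per-item field configs (see `_get_mm_fields_config` below) are what make such a cache usable: without knowing which slice of `pixel_values` belongs to which image, outputs could only be cached for a whole request at once.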
@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -32,10 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalInputsV2,
+                                    MultiModalKwargs, NestedTensors,
+                                    PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+                                        ProcessorInputs, PromptReplacement,
+                                        _BoundPromptReplacement,
+                                        _PlaceholderInfo)
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
@@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens(
     *,
     num_crops: Optional[int] = None,
 ) -> int:
-    mm_processor_kwargs = {}
+    hf_processor_mm_kwargs = {}
     if num_crops:
-        mm_processor_kwargs["num_crops"] = num_crops
+        hf_processor_mm_kwargs["num_crops"] = num_crops
 
-    processor = ctx.get_hf_processor(**mm_processor_kwargs)
+    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
 
     return processor.calc_num_image_tokens_from_image_size(
         width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
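
For context, `calc_num_image_tokens_from_image_size` is the HF Phi-3-Vision image processor's counting routine; a paraphrased sketch of its rule, with constants recalled from the HF code (treat them as illustrative, not authoritative):

```python
def approx_phi3v_image_tokens(width: int, height: int) -> int:
    # Illustrative paraphrase of the HF rule: the HD transform tiles the
    # (already padded) image into 336x336 crops, each crop costs 144 tokens,
    # one extra global crop is added, and each crop row adds separator tokens.
    # Assumes width/height are already padded to multiples of 336.
    w_tiles, h_tiles = width // 336, height // 336
    return (h_tiles * w_tiles + 1) * 144 + 1 + (h_tiles + 1) * 12
```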
@@ -331,39 +335,50 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
 
     def _call_hf_processor(
         self,
         hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             hf_processor,
             prompt=prompt,
-            processor_data=processor_data,
-            mm_processor_kwargs=mm_processor_kwargs,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
+        input_ids = processed_outputs["input_ids"]
+        assert isinstance(input_ids, torch.Tensor)
+
         # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
         # which will cause OverflowError when decoding the prompt_ids.
         # Therefore, we need to do an early replacement here
-        token_ids = processed_outputs['input_ids']
-        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
-        processed_outputs['input_ids'] = token_ids
+        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
 
         return processed_outputs
 
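
The early replacement can be exercised in isolation; `masked_fill_` rewrites the negative placeholder IDs in place without reallocating the tensor (the surrounding token IDs here are illustrative, while `_IMAGE_TOKEN_ID` is 32044 for Phi-3-Vision):

```python
import torch

_IMAGE_TOKEN_ID = 32044  # Phi-3-Vision's <|image|> token ID

# HF's Phi3v processor marks each image with runs of negative IDs (-1 for the
# first image, -2 for the second, ...), which a tokenizer cannot decode.
input_ids = torch.tensor([1, -1, -1, -2, -2, 29871, 306])
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
assert input_ids.min().item() >= 0
```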
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_sizes=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
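
`_get_mm_fields_config` tells the merged processor how the HF outputs are laid out: `batched("image")` means dimension 0 of each tensor indexes image items. A toy illustration of the per-item slicing this enables (hypothetical helper, not the real `MultiModalFieldConfig` machinery):

```python
import torch

def split_batched(t: torch.Tensor) -> list[torch.Tensor]:
    # "batched" layout: element i along dim 0 belongs to image i, so each
    # image's processed tensors can be cached and recombined independently.
    return list(t.unbind(dim=0))

pixel_values = torch.rand(2, 17, 3, 336, 336)  # toy shape: 2 images, 17 crops each
per_image = split_batched(pixel_values)
assert per_image[0].shape == (17, 3, 336, 336)
```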
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
         image_processor = hf_processor.image_processor  # type: ignore
 
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
+        tokenizer = self._get_tokenizer()
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
 
         def get_replacement_phi3v(item_idx: int):
             image_size = mm_items.get_image_size(item_idx)
@@ -372,21 +387,44 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
                 height=image_size.height,
             )
 
-            return [_IMAGE_TOKEN_ID] * num_tokens
+            return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
 
         return [
             PromptReplacement(
                 modality="image",
                 target=image_token,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:max_images]
+            ) for image_token in image_tokens[:len(mm_items.images)]
         ]
 
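
The trailing `bos_token_id` mirrors what the HF processor emits after each image's placeholder run. A toy rendering of one replacement (plain Python, hypothetical token count):

```python
_IMAGE_TOKEN_ID = 32044
bos_token_id = 1
num_tokens = 313  # hypothetical result of calc_num_image_tokens_from_image_size

# Each "<|image_1|>", "<|image_2|>", ... target in the prompt is swapped for
# the per-item expansion returned by get_replacement_phi3v(item_idx).
replacement = [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
assert len(replacement) == num_tokens + 1 and replacement[-1] == bos_token_id
```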
+    def _apply_prompt_replacements(
+        self,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
+    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
+        token_ids, text, placeholders = super()._apply_prompt_replacements(
+            token_ids=token_ids,
+            prompt_repls=prompt_repls,
+            mm_item_counts=mm_item_counts,
+        )
+
+        # Keep the behavior in line with HF processor
+        if text.startswith("<s> <|image|>"):
+            text = text.replace("<s> <|image|>", "<s><|image|>", 1)
+            token_ids = [token_ids[0], *token_ids[2:]]
+            placeholders = [
+                _PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
+                for p in placeholders
+            ]
+
+        return token_ids, text, placeholders
+
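
This override compensates for a detokenization quirk: re-encoding the replaced prompt with a Llama-style tokenizer yields "<s> <|image|>" (an extra space token after BOS), whereas HF's Phi-3V processor produces "<s><|image|>". A toy view of the correction (plain lists, illustrative IDs):

```python
# After replacement the sequence starts [bos, space, image, image, ...];
# dropping index 1 matches HF's "<s><|image|>" and shifts every recorded
# placeholder one position to the left.
token_ids = [1, 29871, 32044, 32044]  # <s>, "space" token (illustrative), <|image|> x2
start_idx = 2                         # where the placeholder run began

token_ids = [token_ids[0], *token_ids[2:]]
start_idx -= 1
assert token_ids == [1, 32044, 32044] and start_idx == 1
```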
     def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)
 
         data = dummy_image_for_clip(
             CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -401,9 +439,28 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text="".join(image_tokens[:num_images]),
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        # Only <|image|> tokens should be considered as placeholders,
+        # so we ignore the trailing bos_token_id
+        result["mm_placeholders"] = {
+            modality: [
+                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
+                for p in ps
+            ]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+
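
Because every replacement now ends with `bos_token_id`, each recorded placeholder range is one token longer than the actual `<|image|>` run, so `apply()` shrinks the lengths by one. In miniature:

```python
# PlaceholderRange-style dicts, as stored in result["mm_placeholders"]:
placeholders = {"image": [{"offset": 0, "length": 314}]}  # 313 image tokens + BOS (toy)

trimmed = {
    modality: [{"offset": p["offset"], "length": p["length"] - 1} for p in ps]
    for modality, ps in placeholders.items()
}
assert trimmed["image"][0]["length"] == 313
```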
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)