[VLM] Support caching in merged multi-modal processor (#11396)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-12-28 01:22:48 +08:00
Committed by: GitHub
Parent: 5ce4627a7e
Commit: 101418096f
20 changed files with 1459 additions and 452 deletions


@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -32,10 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalInputsV2,
+                                    MultiModalKwargs, NestedTensors,
+                                    PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+                                        ProcessorInputs, PromptReplacement,
+                                        _BoundPromptReplacement,
+                                        _PlaceholderInfo)
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens(
     *,
     num_crops: Optional[int] = None,
 ) -> int:
-    mm_processor_kwargs = {}
+    hf_processor_mm_kwargs = {}
     if num_crops:
-        mm_processor_kwargs["num_crops"] = num_crops
+        hf_processor_mm_kwargs["num_crops"] = num_crops
-    processor = ctx.get_hf_processor(**mm_processor_kwargs)
+    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
 
     return processor.calc_num_image_tokens_from_image_size(
         width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
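
Note: building the kwargs dict conditionally here is deliberate. An unset num_crops never reaches the HF processor, so the processor falls back to its own configured default rather than being overridden with None. A minimal, self-contained sketch of the pattern (the helper name is ours, purely illustrative):

from typing import Optional

def build_hf_processor_mm_kwargs(num_crops: Optional[int] = None) -> dict:
    # Forward num_crops only when the caller set it; otherwise the HF
    # processor keeps its configured default.
    hf_processor_mm_kwargs = {}
    if num_crops:
        hf_processor_mm_kwargs["num_crops"] = num_crops
    return hf_processor_mm_kwargs

assert build_hf_processor_mm_kwargs() == {}
assert build_hf_processor_mm_kwargs(16) == {"num_crops": 16}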
@@ -331,39 +335,50 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
     def _call_hf_processor(
         self,
         hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             hf_processor,
             prompt=prompt,
-            processor_data=processor_data,
-            mm_processor_kwargs=mm_processor_kwargs,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
+        input_ids = processed_outputs["input_ids"]
+        assert isinstance(input_ids, torch.Tensor)
+
         # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
         # which will cause OverflowError when decoding the prompt_ids.
         # Therefore, we need to do an early replacement here
-        token_ids = processed_outputs['input_ids']
-        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
-        processed_outputs['input_ids'] = token_ids
+        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
 
         return processed_outputs
 
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_sizes=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
         image_processor = hf_processor.image_processor  # type: ignore
 
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
+        tokenizer = self._get_tokenizer()
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
 
         def get_replacement_phi3v(item_idx: int):
             image_size = mm_items.get_image_size(item_idx)
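
Two notes on the hunk above. MultiModalFieldConfig.batched("image") declares that the first dimension of each HF output tensor indexes individual image items, which is what lets the merged processor split the outputs apart and cache them per item. And the masked_fill_ rewrite of _call_hf_processor is behavior-preserving, just in-place; a runnable sketch of what it does (the concrete id mirrors this file's _IMAGE_TOKEN_ID constant, but treat the numbers as illustrative):

import torch

_IMAGE_TOKEN_ID = 32044  # the <|image|> placeholder id used in this file

# The HF Phi-3-vision processor emits -1, -2, ... as per-image markers;
# decoding negative ids raises OverflowError, so they are rewritten in place.
input_ids = torch.tensor([1, -1, -1, 29871, -2, 2])
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
print(input_ids.tolist())  # [1, 32044, 32044, 29871, 32044, 2]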
@@ -372,21 +387,44 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
                 height=image_size.height,
             )
-            return [_IMAGE_TOKEN_ID] * num_tokens
+            return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
 
         return [
             PromptReplacement(
                 modality="image",
                 target=image_token,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:max_images]
+            ) for image_token in image_tokens[:len(mm_items.images)]
         ]
 
+    def _apply_prompt_replacements(
+        self,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
+    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
+        token_ids, text, placeholders = super()._apply_prompt_replacements(
+            token_ids=token_ids,
+            prompt_repls=prompt_repls,
+            mm_item_counts=mm_item_counts,
+        )
+
+        # Keep the behavior in line with HF processor
+        if text.startswith("<s> <|image|>"):
+            text = text.replace("<s> <|image|>", "<s><|image|>", 1)
+            token_ids = [token_ids[0], *token_ids[2:]]
+            placeholders = [
+                _PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
+                for p in placeholders
+            ]
+
+        return token_ids, text, placeholders
+
     def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)
 
         data = dummy_image_for_clip(
             CLIP_VIT_LARGE_PATCH14_336_CONFIG,
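
The _apply_prompt_replacements override above handles a tokenizer mismatch: token-level replacement yields the text "<s> <|image|>", while the HF processor produces "<s><|image|>" with no space token between bos and the placeholder, so the override drops the token at index 1 and shifts each placeholder's start index left by one. A toy illustration with made-up token ids:

# Made-up ids for illustration: BOS, a spurious space token, and the image
# placeholder.
BOS, SPACE, IMG = 1, 29871, 32044

token_ids = [BOS, SPACE, IMG, IMG, IMG]  # output of the replacement pass
start_idx = 2                            # placeholder begins at the first IMG

# Drop the space token and shift the placeholder to match.
token_ids = [token_ids[0], *token_ids[2:]]
start_idx -= 1

assert token_ids == [BOS, IMG, IMG, IMG]
assert token_ids[start_idx] == IMG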
@@ -401,9 +439,28 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text="".join(image_tokens[:num_images]),
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        # Only <|image|> tokens should be considered as placeholders,
+        # so we ignore the trailing bos_token_id
+        result["mm_placeholders"] = {
+            modality: [
+                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
+                for p in ps
+            ]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
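
Because get_replacement_phi3v now appends bos_token_id after the image tokens (mirroring the HF processor's output), each placeholder range reported by the base class is one token longer than the actual run of <|image|> tokens; the apply override trims that trailing token. A sketch with plain dicts standing in for PlaceholderRange:

# Hypothetical placeholder map: one image expanded to 4 <|image|> tokens
# plus the trailing bos emitted by get_replacement_phi3v.
mm_placeholders = {"image": [{"offset": 1, "length": 5}]}

trimmed = {
    modality: [
        {"offset": p["offset"], "length": p["length"] - 1} for p in ps
    ]
    for modality, ps in mm_placeholders.items()
}

assert trimmed == {"image": [{"offset": 1, "length": 4}]}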