diff --git a/.buildkite/test_areas/models_distributed.yaml b/.buildkite/test_areas/models_distributed.yaml
index 9df1bf830..55e7410b8 100644
--- a/.buildkite/test_areas/models_distributed.yaml
+++ b/.buildkite/test_areas/models_distributed.yaml
@@ -18,5 +18,6 @@ steps:
   # Avoid importing model tests that cause CUDA reinitialization error
   - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/language -v -s -m 'distributed(num_gpus=2)'
-  - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
+  - pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
+  - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
   - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
diff --git a/tests/models/multimodal/generation/test_phi4siglip.py b/tests/models/multimodal/generation/test_phi4siglip.py
new file mode 100644
index 000000000..e8f4ba829
--- /dev/null
+++ b/tests/models/multimodal/generation/test_phi4siglip.py
@@ -0,0 +1,187 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+import pytest
+import regex as re
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from vllm.logprobs import SampleLogprobs
+from vllm.multimodal.image import rescale_image_size
+
+from ....conftest import (
+    IMAGE_ASSETS,
+    HfRunner,
+    PromptImageInput,
+    VllmRunner,
+)
+from ....utils import multi_gpu_test
+from ...utils import check_logprobs_close
+
+MODEL_ID = "microsoft/Phi-4-reasoning-vision-15B"
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
+    {
+        "stop_sign": "<|user|>\n<image>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
+        "cherry_blossom": "<|user|>\n<image>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n",  # noqa: E501
+    }
+)
+HF_MULTIIMAGE_IMAGE_PROMPT = (
+    "<|user|>\n<image>\n<image>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501
+)
+
+DTYPE = "half"
+MAX_TOKENS = 128
+NUM_LOGPROBS = 10
+
+
+def vllm_to_hf_output(
+    vllm_output: tuple[list[int], str, SampleLogprobs | None], model: str
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    _, output_str, out_logprobs = vllm_output
+
+    output_str_without_image = re.sub(r"(<image>)+", "", output_str)
+    if output_str_without_image and output_str_without_image[0] == " ":
+        output_str_without_image = output_str_without_image[1:]
+
+    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"
+
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    hf_output_ids = tokenizer.encode(output_str_without_image)
+    if hf_output_ids and hf_output_ids[0] == tokenizer.bos_token_id:
+        hf_output_ids = hf_output_ids[1:]
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+def _build_single_image_inputs(
+    image_assets,
+) -> list[tuple[list[str], PromptImageInput]]:
+    """Build single-image inputs for all size_factors at once."""
+    images = [asset.pil_image for asset in image_assets]
+    all_inputs: list[tuple[list[str], PromptImageInput]] = []
+    for size_factors in [[1.0], [0.25, 0.5, 1.0]]:
+        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
+            all_inputs.append(
+                (
+                    [prompt for _ in size_factors],
+                    [rescale_image_size(image, f) for f in size_factors],
+                )
+            )
+    return all_inputs
+
+
+def _build_multi_image_inputs(
+    image_assets,
+) -> list[tuple[list[str], PromptImageInput]]:
+    """Build multi-image inputs for all size_factors at once."""
+    images = [asset.pil_image for asset in image_assets]
+    all_inputs: list[tuple[list[str], PromptImageInput]] = []
+    for size_factors in [[0.5], [0.15, 0.30]]:
+        all_inputs.append(
+            (
+                [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+                [
+                    [rescale_image_size(image, factor) for image in images]
+                    for factor in size_factors
+                ],
+            )
+        )
+    return all_inputs
+
+
+def _run_and_compare(
+    hf_runner: type[HfRunner],
+    vllm_runner: type[VllmRunner],
+    all_inputs: Sequence[tuple[list[str], PromptImageInput]],
+    model: str,
+    max_model_len: int,
+    max_num_seqs: int,
+    mm_limit: int,
+    gpu_memory_utilization: float,
+):
+    """Load each runner once, run all inputs, then compare."""
+    # NOTE: run vLLM first, then HF. vLLM needs a fresh process without CUDA
+    # initialization; running HF first would break the fork-based
+    # multiprocessing backend.
+    with vllm_runner(
+        model,
+        runner="generate",
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
+        gpu_memory_utilization=gpu_memory_utilization,
+        dtype=DTYPE,
+        limit_mm_per_prompt={"image": mm_limit},
+        tensor_parallel_size=2,
+        trust_remote_code=True,
+        enforce_eager=True,
+    ) as vllm_model:
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(
+                prompts,
+                MAX_TOKENS,
+                num_logprobs=NUM_LOGPROBS,
+                images=images,
+            )
+            for prompts, images in all_inputs
+        ]
+
+    hf_model_kwargs = {"_attn_implementation": "sdpa", "device_map": "auto"}
+    with hf_runner(
+        model,
+        dtype=DTYPE,
+        model_kwargs=hf_model_kwargs,
+        auto_cls=AutoModelForCausalLM,
+        trust_remote_code=True,
+    ) as hf_model:
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(
+                prompts,
+                MAX_TOKENS,
+                num_logprobs=NUM_LOGPROBS,
+                images=images,
+            )
+            for prompts, images in all_inputs
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("model", [MODEL_ID])
+def test_models(hf_runner, vllm_runner, image_assets, model) -> None:
+    all_inputs = _build_single_image_inputs(image_assets)
+    _run_and_compare(
+        hf_runner,
+        vllm_runner,
+        all_inputs,
+        model,
+        max_model_len=8192,
+        max_num_seqs=2,
+        mm_limit=1,
+        gpu_memory_utilization=0.80,
+    )
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("model", [MODEL_ID])
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model) -> None:
+    all_inputs = _build_multi_image_inputs(image_assets)
+    _run_and_compare(
+        hf_runner,
+        vllm_runner,
+        all_inputs,
+        model,
+        max_model_len=8192,
+        max_num_seqs=2,
+        mm_limit=2,
+        gpu_memory_utilization=0.80,
+    )
diff --git a/tests/models/registry.py b/tests/models/registry.py
index acfc4786e..e8b0f2191 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -1049,6 +1049,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         },  # noqa: E501
         extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"},
     ),
+    "Phi4ForCausalLMV": _HfExamplesInfo(
+        "microsoft/Phi-4-reasoning-vision-15B", trust_remote_code=True
+    ),
     "Phi4MMForCausalLM": _HfExamplesInfo(
         "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
     ),
diff --git a/vllm/model_executor/models/phi4siglip.py b/vllm/model_executor/models/phi4siglip.py
new file mode 100644
index 000000000..d71a572f6
--- /dev/null
+++ b/vllm/model_executor/models/phi4siglip.py
@@ -0,0 +1,429 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""vLLM support for microsoft/Phi-4-reasoning-vision-15B.
+
+Architecture: Siglip2 vision tower + MLP projector + Phi3 language model.
+"""
+
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Any, Literal
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature, PretrainedConfig, Siglip2VisionConfig
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.inputs import MultiModalDataDict
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+    ImageSize,
+    MultiModalDataItems,
+)
+from vllm.multimodal.processing import (
+    BaseDummyInputsBuilder,
+    PromptReplacement,
+    PromptUpdate,
+)
+from vllm.multimodal.processing.processor import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .lfm2_siglip2 import Siglip2Model
+from .llava import LlavaMultiModalProjector
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+
+# The HF processor replaces "<image>" with IMAGE_TOKEN_INDEX (-200) in
+# input_ids. Negative token IDs cause OverflowError during decoding, so we
+# remap to a real in-vocabulary token. The Phi-4-reasoning-vision tokenizer
+# ships with reserved dummy tokens (<|dummy_0|> … <|dummy_83|>); we reuse the
+# first one as the image placeholder. This mirrors how Phi-3-vision uses its
+# dedicated <|image|> token (ID 32044).
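+#
+# A quick, illustrative sanity check of that assumption (the exact ID is taken
+# from this mapping and may differ for other tokenizer revisions):
+#
+#   from transformers import AutoTokenizer
+#   tok = AutoTokenizer.from_pretrained(
+#       "microsoft/Phi-4-reasoning-vision-15B", trust_remote_code=True
+#   )
+#   assert tok.convert_ids_to_tokens([100256]) == ["<|dummy_0|>"]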
+_IMAGE_TOKEN_ID = 100256  # <|dummy_0|> in the Phi-4 tokenizer
+
+
+# ---------------------------------------------------------------------------
+# Processing
+# ---------------------------------------------------------------------------
+
+
+class Phi4SiglipProcessingInfo(BaseProcessingInfo):
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        return {"image": None}
+
+    def _get_vision_config(self) -> dict:
+        return self.get_hf_config().vision_config  # type: ignore[attr-defined]
+
+    def _get_patch_size(self) -> int:
+        vc = self._get_vision_config()
+        if isinstance(vc, dict):
+            return vc.get("patch_size", 16)
+        return getattr(vc, "patch_size", 16)
+
+    def _get_max_num_patches(self) -> int:
+        return getattr(self.get_hf_config(), "max_num_patches", 3600)
+
+    def _get_min_num_patches(self) -> int:
+        return getattr(self.get_hf_config(), "min_num_patches", 256)
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        patch_size = self._get_patch_size()
+        min_patches = self._get_min_num_patches()
+        max_patches = self._get_max_num_patches()
+
+        num_patches_h = image_height // patch_size
+        num_patches_w = image_width // patch_size
+        num_patches = max(num_patches_h * num_patches_w, 1)
+        num_patches = max(min(num_patches, max_patches), min_patches)
+        return num_patches
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        patch_size = self._get_patch_size()
+        max_patches = self._get_max_num_patches()
+        side = int(math.sqrt(max_patches)) * patch_size
+        return ImageSize(width=side, height=side)
+
+    def get_mm_max_tokens_per_item(
+        self, seq_len: int, mm_counts: Mapping[str, int]
+    ) -> Mapping[str, int]:
+        return {"image": self._get_max_num_patches()}
+
+
+class Phi4SiglipDummyInputsBuilder(
+    BaseDummyInputsBuilder[Phi4SiglipProcessingInfo],
+):
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_images = mm_counts.get("image", 0)
+        return DEFAULT_IMAGE_TOKEN * num_images
+
+    def get_dummy_mm_data(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+        mm_options: Mapping[str, BaseDummyOptions],
+    ) -> MultiModalDataDict:
+        num_images = mm_counts.get("image", 0)
+        size = self.info.get_image_size_with_most_features()
+        return {
+            "image": self._get_dummy_images(
+                width=size.width,
+                height=size.height,
+                num_images=num_images,
+                overrides=mm_options.get("image"),
+            ),
+        }
+
+
+class Phi4SiglipMultiModalProcessor(
+    BaseMultiModalProcessor[Phi4SiglipProcessingInfo],
+):
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+            tok_kwargs=tok_kwargs,
+        )
+
+        # The HF processor's tokenizer_image_token() replaces the "<image>"
+        # string with IMAGE_TOKEN_INDEX (-200) in input_ids. This breaks
+        # vLLM's prompt-replacement pipeline, which needs to find "<image>"
+        # as normal sub-tokens. Re-tokenize with the plain tokenizer so that
+        # "<image>" stays as sub-tokens and can be located by
+        # PromptReplacement.
+        # NOTE: tokenizer.__call__() (not .encode()) must be used so that
+        # added/special tokens like <|user|>, <|end|> are kept as single IDs.
+        tokenizer = self.info.get_tokenizer()
+        new_ids = tokenizer(prompt).input_ids
+        processed["input_ids"] = torch.tensor([new_ids])
+
+        return processed
+
+    def _hf_processor_applies_updates(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> bool:
+        # The HF processor replaces "<image>" with a single -200 placeholder
+        # but does NOT expand it into N vision-encoder tokens. Since we also
+        # re-tokenize the prompt (see _call_hf_processor), prompt updates are
+        # never applied by the HF processor; vLLM handles the expansion via
+        # _apply_prompt_updates.
+        return False
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_attention_mask=MultiModalFieldConfig.batched("image"),
+            spatial_shapes=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, Any],
+        out_mm_kwargs: MultiModalKwargsItems,
+    ) -> Sequence[PromptUpdate]:
+        def get_replacement(item_idx: int):
+            # Read the actual patch grid from the NaFlex processor's
+            # spatial_shapes output (same pattern as LFM2-VL). This avoids
+            # predicting from raw image dimensions, which can diverge from
+            # the NaFlex resize/tile logic.
+            out_item = out_mm_kwargs["image"][item_idx]
+            spatial_shapes = out_item["spatial_shapes"].data
+            assert isinstance(spatial_shapes, torch.Tensor)
+            num_tokens = int(spatial_shapes.prod().item())
+            return [_IMAGE_TOKEN_ID] * num_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=DEFAULT_IMAGE_TOKEN,
+                replacement=get_replacement,
+            ),
+        ]
+
+
+# ---------------------------------------------------------------------------
+# Input schemas
+# ---------------------------------------------------------------------------
+
+
+class Phi4SiglipImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - d: Max number of patches (padded across images in the batch)
+        - fd: Features per patch (patch_size * patch_size * channels)
+    """
+
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", "d", "fd")]
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bn", "d")]
+    spatial_shapes: Annotated[torch.Tensor, TensorShape("bn", 2)]
+
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Phi4SiglipMultiModalProcessor,
+    info=Phi4SiglipProcessingInfo,
+    dummy_inputs=Phi4SiglipDummyInputsBuilder,
+)
+class Phi4ForCausalLMV(nn.Module, SupportsMultiModal, SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.vision_tower.vision_tower.vision_model.head.": None,
+            "model.vision_tower.vision_tower.": "vision_tower.",
+            "model.mm_projector.0.": "multi_modal_projector.linear_1.",
+            "model.mm_projector.2.": "multi_modal_projector.linear_2.",
+            "lm_head.": "language_model.lm_head.",
+            "model.": "language_model.model.",
+        },
+    )
+
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return DEFAULT_IMAGE_TOKEN
+        raise ValueError("Only image modality is supported")
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
+
+        config: PretrainedConfig = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+
+        vision_config_dict: dict = getattr(config, "vision_config", {})
+        if isinstance(vision_config_dict, dict):
+            if "patch_size" not in vision_config_dict:
+                vision_config_dict["patch_size"] = 16
+            siglip2_config = Siglip2VisionConfig(**vision_config_dict)
+        else:
+            siglip2_config = vision_config_dict
+
+        vision_hidden_size: int = config.mm_hidden_size  # type: ignore[attr-defined]
+        text_hidden_size: int = config.hidden_size  # type: ignore[attr-defined]
+
+        with self._mark_tower_model(vllm_config, "image"):
+            layer_idx = -2
+            num_hidden_layers = siglip2_config.num_hidden_layers + layer_idx + 1
+
+            self.vision_tower = Siglip2Model(
+                siglip2_config,
+                quant_config=quant_config,
+                num_hidden_layers_override=num_hidden_layers,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
+            self.multi_modal_projector = LlavaMultiModalProjector(
+                vision_hidden_size=vision_hidden_size,
+                text_hidden_size=text_hidden_size,
+                projector_hidden_act="gelu",
+                multimodal_projector_bias=True,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"),
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config,
+                prefix=maybe_prefix(prefix, "language_model"),
+                architectures=["Phi3ForCausalLM"],
+            )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+        self.configure_mm_token_handling(
+            vocab_size=config.vocab_size,  # type: ignore[attr-defined]
+            mm_token_ids=[_IMAGE_TOKEN_ID],
+        )
+
+    def _packed_from_padded(
+        self,
+        pixel_values: torch.Tensor,
+        pixel_attention_mask: torch.Tensor,
+        spatial_shapes: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Convert padded NaFlex tensors to packed format for Siglip2Model."""
+        valid_counts = pixel_attention_mask.sum(dim=1).to(torch.int32)
+        pixel_values_packed = pixel_values[pixel_attention_mask.bool()]
+        cu_seqlens = torch.zeros(
+            len(valid_counts) + 1,
+            dtype=torch.int32,
+            device=pixel_values.device,
+        )
+        cu_seqlens[1:] = valid_counts.cumsum(0)
+        max_seqlen = valid_counts.max()
+        return (
+            pixel_values_packed,
+            spatial_shapes,
+            cu_seqlens,
+            max_seqlen,
+        )
+
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> Phi4SiglipImagePixelInputs | None:
+        pixel_values = kwargs.pop("pixel_values", None)
+        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
+        spatial_shapes = kwargs.pop("spatial_shapes", None)
+        if pixel_values is None:
+            return None
+
+        return Phi4SiglipImagePixelInputs(
+            type="pixel_values",
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            spatial_shapes=spatial_shapes,
+        )
+
+    def _process_image_input(
+        self, image_input: Phi4SiglipImagePixelInputs
+    ) -> MultiModalEmbeddings:
+        pixel_values = image_input["pixel_values"]
+        pixel_attention_mask = image_input["pixel_attention_mask"]
+        spatial_shapes = image_input["spatial_shapes"]
+
+        (
+            pixel_values_packed,
+            spatial_shapes_packed,
+            cu_seqlens,
+            max_seqlen,
+        ) = self._packed_from_padded(pixel_values, pixel_attention_mask, spatial_shapes)
+
+        vision_features = self.vision_tower(
+            pixel_values_packed=pixel_values_packed,
+            spatial_shapes=spatial_shapes_packed,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            select_layers=[-2],
+        )
+
+        if vision_features.dim() == 3:
+            vision_features = vision_features.squeeze(0)
+
+        image_features = self.multi_modal_projector(vision_features)
+
+        valid_counts = pixel_attention_mask.sum(dim=1).tolist()
+        return torch.split(image_features, [int(c) for c in valid_counts])
+
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return []
+
+        return self._process_image_input(image_input)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: object,
+    ) -> torch.Tensor | IntermediateTensors:
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        hidden_states = self.language_model.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        return self.language_model.compute_logits(hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index d52a3e48a..1901381cb 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -481,6 +481,7 @@ _MULTIMODAL_MODELS = {
         "PaliGemmaForConditionalGeneration",
     ),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
+    "Phi4ForCausalLMV": ("phi4siglip", "Phi4ForCausalLMV"),
     "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),
     "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),