[VLM] Move supported limits and max tokens to merged multi-modal processor (#11669)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Cyrus Leung
2025-01-01 23:44:42 +08:00
committed by GitHub
parent 73001445fb
commit a115ac46b5
16 changed files with 351 additions and 361 deletions

View File

@@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -306,25 +305,32 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
def get_max_phi3v_image_tokens(
ctx: InputContext,
*,
num_crops: Optional[int] = None,
) -> int:
hf_processor_mm_kwargs = {}
if num_crops:
hf_processor_mm_kwargs["num_crops"] = num_crops
processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
return processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
processor = self._get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return {"image": max_image_tokens}
def _get_hf_processor(
self,
*,
@@ -332,6 +338,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
) -> ProcessorMixin:
if num_crops is not None:
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()
def _call_hf_processor(
@@ -375,7 +382,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
tokenizer = self._get_tokenizer()
bos_token_id = tokenizer.bos_token_id
@@ -385,9 +391,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
num_tokens = self._get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
@@ -467,7 +473,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(