[Model] Add smolvlm support (#16017)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
vllm/model_executor/models/idefics3.py
@@ -206,6 +206,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         return grid_w * grid_h + 1
 
+    def _get_image_token(
+            self,
+            processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
+        if processor is None:
+            processor = self.get_hf_processor()
+        image_token = processor.image_token.content
+        fake_image_token = processor.fake_image_token.content
+        global_image_token = processor.global_image_tag
+        return image_token, fake_image_token, global_image_token
+
     def get_image_repl(
         self,
         *,
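For reference, the three strings returned by the new _get_image_token helper are the processor's special image tokens. A quick illustrative check (the checkpoint name and the commented values are assumptions based on the HF Idefics3 processor, not part of this diff):

# Illustrative only: inspect the tokens the helper reads from an Idefics3 processor.
# Checkpoint name and the expected values below are assumptions.
from transformers import Idefics3Processor

processor = Idefics3Processor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
print(processor.image_token.content)       # expected: "<image>"
print(processor.fake_image_token.content)  # expected: "<fake_token_around_image>"
print(processor.global_image_tag)          # expected: "<global-img>"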
@@ -216,9 +226,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         if processor is None:
             processor = self.get_hf_processor()
 
-        image_token = processor.image_token.content
-        fake_image_token = processor.fake_image_token.content
-        global_img_token = processor.global_image_tag
+        image_token, fake_image_token, global_img_token = self._get_image_token(
+            processor)
         image_seq_len = processor.image_seq_len
         grid_placeholder = "<row_{n_h}_col_{n_w}>"
@@ -300,7 +309,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
         hf_processor = self.info.get_hf_processor()
         image_processor: Idefics3ImageProcessor = hf_processor.image_processor
         longest_edge = image_processor.max_image_size['longest_edge']
-        image_token = hf_processor.image_token.content
+        image_token, _, _ = self.info._get_image_token(hf_processor)
 
         mm_data = {
             "image":
@@ -382,7 +391,7 @@ class Idefics3MultiModalProcessor(
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        image_token = hf_processor.image_token.content
+        image_token, _, _ = self.info._get_image_token(hf_processor)
 
         def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
             images = mm_items.get_items("image", ImageProcessorItems)
vllm/model_executor/models/registry.py
@@ -175,6 +175,7 @@ _MULTIMODAL_MODELS = {
     "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
     "InternVLChatModel": ("internvl", "InternVLChatModel"),
     "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
+    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
     "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
     "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
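The registry key is the architecture name a checkpoint lists in its config.json, so SmolVLM checkpoints now resolve to the new module. A minimal sketch of the lookup, assuming vLLM's public ModelRegistry API (illustrative, not part of this commit):

# Illustrative only: resolve the newly registered architecture name.
from vllm import ModelRegistry

model_cls, arch = ModelRegistry.resolve_model_cls(
    ["SmolVLMForConditionalGeneration"])
print(arch)       # "SmolVLMForConditionalGeneration"
print(model_cls)  # the class defined in vllm/model_executor/models/smolvlm.py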
vllm/model_executor/models/smolvlm.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, Optional
+
+from transformers import SmolVLMProcessor
+
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+# yapf: disable
+from .idefics3 import Idefics3DummyInputsBuilder as SmolVLMDummyInputsBuilder
+from .idefics3 import Idefics3ForConditionalGeneration
+from .idefics3 import Idefics3MultiModalProcessor as SmolVLMMultiModalProcessor
+from .idefics3 import Idefics3ProcessingInfo
+# yapf: enable
+
+
+class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
+
+    def get_hf_processor(
+        self,
+        *,
+        max_image_size: Optional[Dict[str, int]] = None,
+        **kwargs: object,
+    ) -> SmolVLMProcessor:
+        if max_image_size is not None:
+            kwargs["max_image_size"] = max_image_size
+
+        return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
+
+    def _get_image_token(
+            self,
+            processor: Optional[SmolVLMProcessor]) -> tuple[str, str, str]:
+        if processor is None:
+            processor = self.get_hf_processor()
+        image_token = processor.image_token
+        fake_image_token = processor.fake_image_token
+        global_image_token = processor.global_image_token
+        return image_token, fake_image_token, global_image_token
+
+
+@MULTIMODAL_REGISTRY.register_processor(SmolVLMMultiModalProcessor,
+                                        info=SmolVLMProcessingInfo,
+                                        dummy_inputs=SmolVLMDummyInputsBuilder)
+class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(
+            vllm_config=vllm_config,
+            prefix=prefix,
+        )
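With the registration in place, the model can be exercised through vLLM's offline API like any other multimodal model. A minimal sketch, assuming the HuggingFaceTB/SmolVLM-Instruct checkpoint and its chat-style prompt format (both assumptions, not part of this commit):

# Minimal usage sketch; checkpoint name and prompt template are assumptions.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="HuggingFaceTB/SmolVLM-Instruct")

image = Image.open("example.jpg")  # any local RGB image
prompt = ("<|im_start|>User:<image>Describe this image.<end_of_utterance>\n"
          "Assistant:")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)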