[Model] Port deepseek-vl2 processor, remove dependency (#12169)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-01-18 13:59:39 +08:00
committed by GitHub
parent 813f249f02
commit 02798ecabe
8 changed files with 385 additions and 49 deletions

View File

@@ -1,7 +1,7 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
from functools import cached_property, partial
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoProcessor, BatchFeature, ProcessorMixin
from transformers import BatchFeature
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
@@ -31,6 +31,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig,
VisionEncoderConfig)
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
@@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)
def get_hf_processor(self) -> ProcessorMixin:
# TODO(Isotr0py): we should get rid of dependency on deepseek_vl2
# in the future, because it's flasky and lack of maintenance.
try:
from deepseek_vl2.models.processing_deepseek_vl_v2 import (
DeepseekVLV2Processor, select_best_resolution)
AutoProcessor.register("DeepseekVLV2Processor",
DeepseekVLV2Processor)
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
"to use this model") from exc
processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
processor.select_best_resolution = partial(
select_best_resolution,
candidate_resolutions=processor.candidate_resolutions)
return processor
def get_hf_processor(self) -> DeepseekVLV2Processor:
return self.ctx.get_hf_processor(DeepseekVLV2Processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -224,31 +209,21 @@ class DeepseekVL2MultiModalProcessor(
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
outputs = self.info.ctx.call_hf_processor(
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
# Deepseek-vl2 processor don't return BatchFeature,
# we need to manually create it
processed_outputs = dict(input_ids=outputs["input_ids"])
processed_outputs = BatchFeature(data=dict(processed_outputs),
tensor_type="pt")
# Remove batch dimension from processor outputs,
# because we will try batch to create NestedTensors
target_dtype = self.info.ctx.model_config.dtype
pixel_values = outputs["images"].to(target_dtype).squeeze(0)
images_spatial_crop = outputs["images_spatial_crop"].squeeze(0)
pixel_values = processed_outputs.pop("pixel_values").to(
target_dtype)
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
x.prod().item() + 1 for x in images_spatial_crop
]
# Rename `images` -> `pixel_values` to avoid confusion
processed_outputs["pixel_values"] = list(
pixel_values.split(patches_per_image))
processed_outputs["images_spatial_crop"] = images_spatial_crop
pixel_values = pixel_values.split(patches_per_image)
processed_outputs["pixel_values"] = pixel_values
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,