[Model] Broadcast Ovis2 implementation to fit Ovis1.6 (#17861)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-05-12 08:56:30 +08:00
committed by GitHub
parent 7de18d541b
commit 021c16c7ca
16 changed files with 330 additions and 212 deletions

View File

@@ -512,7 +512,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "ovis2", "skywork_chat",
"internvl_chat", "ovis", "skywork_chat",
"NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
return "<image>"
if model_type in ("mllama", "llama4"):

View File

@@ -5,129 +5,14 @@
from typing import Optional
import torch
from torch import nn, softmax
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.functional import gumbel_softmax, pad
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
Aimv2VisualTokenizerConfig)
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304,
-305] # kept for vocab prefixed tokens
def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(
y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
class Aimv2VisualTokenizer(torch.nn.Module):
def __init__(self,
config: Aimv2VisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs):
super().__init__()
self.config = config
self.backbone = AIMv2Model(
config=config.backbone_config, # noqa
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer")
# reserved tokens for IMAGE_INDICATORS
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
self.head = torch.nn.Sequential(
ReplicatedLinear(
config.backbone_config.hidden_size * config.hidden_stride *
config.hidden_stride,
head_dim,
bias=False,
), torch.nn.LayerNorm(head_dim))
@property
def dtype(self):
return self.backbone.dtype
@property
def device(self):
return self.backbone.device
def tokenize(self, logits):
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
'Invalid `max_type`, expected softmax or gumbel_argmax '
f'or st_argmax, but got {self.config.tokenize_function}')
return tokens
def encode(self, pixel_values):
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]
# merge number of `hidden_stride * hidden_stride` hidden states together
# to reduce token sequence length
# e.g., for hidden_stride=2, this leads to a token length reduction:
# 1024 -> 256 for aimv2
if self.config.hidden_stride > 1:
# this `d` maybe different from the above `d``
n, L, d = features.shape
sqrt_l = int(L**0.5)
assert sqrt_l**2 == L, (
"The token sequence length should be a perfect square.")
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride -
(sqrt_l %
self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride,
self.config.hidden_stride,
sqrt_l // self.config.hidden_stride,
self.config.hidden_stride, d)
# [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.permute(0, 1, 3, 2, 4, 5)
# [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.flatten(3)
# [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1,
self.config.hidden_stride * self.config.hidden_stride * d)
return features
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
features = self.encode(pixel_values)
logits, _ = self.head[0](
features) # we spllit the sequncial here for not throwing an error
logits = self.head[1](logits)
tokens = self.tokenize(logits)
# tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
# [BatchSize, #Token, 5], after which, tokens' shape should become
# [BatchSize, #Token, VocabSize]
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len,
len(IMAGE_INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
return tokens
from vllm.transformers_utils.configs.ovis import AIMv2Config
class AIMv2SwiGLUFFN(nn.Module):
@@ -302,14 +187,6 @@ class AIMv2Model(torch.nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.trunk")
@property
def dtype(self):
return self.trunk.blocks[0].attn.qkv.weight.dtype
@property
def device(self):
return self.trunk.blocks[0].attn.qkv.device
def forward(
self,
pixel_values: torch.Tensor,

View File

@@ -15,17 +15,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Ovis2 model."""
""" PyTorch Ovis model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from torch import Tensor
from transformers import BatchFeature
from torch.nn.functional import gumbel_softmax, pad, softmax
from transformers import BaseImageProcessor, BatchFeature
from vllm.config import VllmConfig
from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models.aimv2 import AIMv2Model
from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
init_vllm_registered_model,
maybe_prefix)
@@ -38,19 +44,160 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ovis2 import OvisConfig
from vllm.transformers_utils.processors.ovis2 import OvisProcessor
from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
OvisConfig)
from vllm.transformers_utils.processors.ovis import OvisProcessor
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import merge_multimodal_embeddings
# Cannot find the following number from hf config.
IMAGE_TOKEN = "<image>"
IMAGE_PAD_TOKEN_ID = 151655
NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
IMAGE_PAD_TOKEN_MAP = {
"gemma2": "<unused0>",
"llama": "<|reserved_special_token_0|>",
"qwen2": "<|image_pad|>",
}
IMAGE_PAD_TOKEN_ID_MAP = {
"gemma2": 7,
"llama": 128002,
"qwen2": 151655,
}
class Ovis2ImagePatchInputs(TypedDict):
def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
index = y_soft.argmax(dim, keepdim=True)
return torch.zeros_like(
y_soft,
memory_format=torch.legacy_contiguous_format,
).scatter_(dim, index, 1.0)
class VisualTokenizer(torch.nn.Module):
def __init__(
self,
config: BaseVisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.backbone = self._init_backbone(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.backbone",
)
# reserved tokens for IMAGE_INDICATORS
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
self.head = torch.nn.Sequential(
ReplicatedLinear(
config.backbone_config.hidden_size * config.hidden_stride *
config.hidden_stride,
head_dim,
bias=False,
return_bias=False,
), torch.nn.LayerNorm(head_dim))
def _init_backbone(
self,
config: BaseVisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
model_type = config.backbone_config.model_type
if model_type == "aimv2":
return AIMv2Model(
config=config.backbone_config,
quant_config=quant_config,
prefix=prefix,
)
elif model_type == "siglip_vision_model":
return SiglipVisionModel(
config=config.backbone_config,
quant_config=quant_config,
prefix=prefix,
)
raise ValueError(
f"Unsupported visual tokenizer model_type: {model_type}")
@property
def dtype(self):
return next(self.head.parameters()).dtype
@property
def device(self):
return next(self.head.parameters()).device
def tokenize(self, logits):
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
'Invalid `max_type`, expected softmax or gumbel_argmax '
f'or st_argmax, but got {self.config.tokenize_function}')
return tokens
def encode(self, pixel_values):
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]
# merge number of `hidden_stride * hidden_stride` hidden states together
# to reduce token sequence length
# e.g., for hidden_stride=2, this leads to a token length reduction:
# 1024 -> 256 for aimv2
if self.config.hidden_stride > 1:
# this `d` maybe different from the above `d``
n, L, d = features.shape
sqrt_l = int(L**0.5)
assert sqrt_l**2 == L, (
"The token sequence length should be a perfect square.")
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride -
(sqrt_l %
self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride,
self.config.hidden_stride,
sqrt_l // self.config.hidden_stride,
self.config.hidden_stride, d)
# [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.permute(0, 1, 3, 2, 4, 5)
# [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.flatten(3)
# [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1,
self.config.hidden_stride * self.config.hidden_stride * d)
return features
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
features = self.encode(pixel_values)
logits = self.head(features)
tokens = self.tokenize(logits)
# tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
# [BatchSize, #Token, 5], after which, tokens' shape should become
# [BatchSize, #Token, VocabSize]
tokens = torch.nn.functional.pad(
tokens,
(0, len(IMAGE_INDICATOR_IDS)),
mode="constant",
value=0,
)
return tokens
class OvisImagePatchInputs(TypedDict):
type: Literal["image_patches"]
flat_data: torch.Tensor
"""
@@ -92,31 +239,50 @@ class VisualEmbedding(torch.nn.Embedding):
return self.weight.dtype
class Ovis2ProcessingInfo(BaseProcessingInfo):
class OvisProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(OvisConfig)
def get_hf_processor(self, **kwargs):
return self.ctx.get_hf_processor(OvisProcessor)
return self.ctx.get_hf_processor(
OvisProcessor,
image_pad_token=self.get_image_pad_token(),
image_segment_len=self.get_image_segment_len(),
)
def get_image_processor(self) -> OvisProcessor:
def get_image_segment_len(self) -> int:
visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
image_size = visual_tokenizer_config.backbone_config.image_size
patch_size = visual_tokenizer_config.backbone_config.patch_size
hidden_stride = visual_tokenizer_config.hidden_stride
patch_grid_length = math.ceil(image_size / patch_size)
assert patch_grid_length % hidden_stride == 0, (
f"patch_grid_length {patch_grid_length} is not divisible by "
f"hidden_stride {hidden_stride}")
# minus 1 for presented image token
return (patch_grid_length // hidden_stride)**2 - 1
def get_image_pad_token(self) -> str:
hf_text_config = self.get_hf_config().get_text_config()
text_model_type = hf_text_config.model_type
return IMAGE_PAD_TOKEN_MAP.get(text_model_type)
def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return { # 32k is model token limit at the moment
"image":
self.get_hf_config().multimodal_max_length //
((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT)
}
return {"image": None}
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2,
height=image_processor.size['shortest_edge'] * 9 * 2)
height, width = self.get_hf_processor().get_image_size()
hs = self.get_hf_config().visual_tokenizer_config.hidden_stride
# NOTE(Isotr0py): 9 is `max_partion` hardcoded in original code
# https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96
return ImageSize(width=width * hs * 9, height=height * hs * 9)
class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@@ -141,7 +307,7 @@ class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
return mm_data
class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
def image_indicators_to_visual_tokens(
self,
@@ -165,9 +331,9 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# # Avoid warning from HF logger for text-only input
prompt_ids = self.info.get_tokenizer().encode(prompt)
# prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope
# Avoid warning from HF logger for text-only input
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor(
@@ -226,10 +392,10 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
]
@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor,
info=Ovis2ProcessingInfo,
dummy_inputs=Ovis2DummyInputsBuilder)
class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
info=OvisProcessingInfo,
dummy_inputs=OvisDummyInputsBuilder)
class Ovis(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -242,24 +408,25 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
prefix=maybe_prefix(prefix, "llm"),
)
self.visual_tokenizer = Aimv2VisualTokenizer(
self.visual_tokenizer = VisualTokenizer(
config=config.visual_tokenizer_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
image_processor_name_or_path=config.visual_tokenizer_config.
backbone_config.name_or_path,
)
self.vte = VisualEmbedding(
self.config.visual_tokenizer_config.vocab_size,
self.config.hidden_size)
text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]:
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
@@ -275,7 +442,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(pixel_values)}")
return Ovis2ImagePatchInputs(
return OvisImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
@@ -288,7 +455,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings:
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
indicator_tokens = image_input["indicator_tokens"]
@@ -338,7 +505,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[IMAGE_PAD_TOKEN_ID])
self.image_pad_token_id)
return inputs_embeds
def forward(
@@ -375,8 +542,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.llm.logits_processor(self.llm.lm_head, hidden_states,
sampling_metadata)
logits = self.llm.compute_logits(hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,

View File

@@ -195,7 +195,7 @@ _MULTIMODAL_MODELS = {
"Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"),
"Ovis": ("ovis", "Ovis"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501

View File

@@ -23,7 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.ovis2 import OvisConfig
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config

View File

@@ -123,6 +123,19 @@ class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = "siglip_visual_tokenizer"
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.drop_cls_token:
self.drop_cls_token = False
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig)
AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)

View File

@@ -2,6 +2,6 @@
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
from vllm.transformers_utils.processors.ovis2 import OvisProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
__all__ = ["DeepseekVLV2Processor", "OvisProcessor"]

View File

@@ -22,6 +22,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Union
import PIL
@@ -32,7 +33,7 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
__all__ = [ 'OvisProcessor']
__all__ = ['OvisProcessor']
IGNORE_ID = -100
class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
@@ -64,18 +65,29 @@ class OvisProcessor(ProcessorMixin):
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
valid_kwargs = ["chat_template", "image_pad_token", "image_segement_len"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "Qwen2Tokenizer"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs):
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
image_pad_token=None,
image_segment_len=255,
**kwargs,
):
self.image_token = "<image>"
self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token
self.image_pad_token = image_pad_token
self.image_segment_len = image_segment_len
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
self.extra_special_tokens = {
@cached_property
def extra_special_tokens(self):
image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
extra_special_tokens = {
"image_token": -200,
"image_atom": -300,
"image_start": -301,
@@ -83,8 +95,9 @@ class OvisProcessor(ProcessorMixin):
"image_col_sep": -303,
"image_row_sep": -304,
"image_end": -305,
'image_pad': self.image_pad_token_id,
'image_pad': image_pad_token_id,
}
return extra_special_tokens
def __call__(
self,
@@ -224,8 +237,14 @@ class OvisProcessor(ProcessorMixin):
return torch.tensor(batch_token_ids, dtype=torch.long)
def get_image_size(self):
height = self.image_processor.crop_size["height"]
width = self.image_processor.crop_size["width"]
size = self.image_processor.size
if 'shortest_edge' in size:
width = height = size['shortest_edge']
elif "height" in size and "width" in size:
width = size['width']
height = size['height']
else:
raise ValueError( "Can't parse image size from image_processor config.")
return height, width
def get_token_value(self, tok):
@@ -259,8 +278,7 @@ class OvisProcessor(ProcessorMixin):
for token in image_placeholders:
padded_placeholder_tokens.append(image_padding_token_id)
if token == image_atom_token_id:
# Add 255 padding tokens after each image atom token
padded_placeholder_tokens.extend([image_padding_token_id] * 255)
padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len)
return padded_placeholder_tokens
def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):