[Model] Broadcast Ovis2 implementation to fit Ovis1.6 (#17861)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py authored 2025-05-12 08:56:30 +08:00 · committed by GitHub
parent 7de18d541b · commit 021c16c7ca
16 changed files with 330 additions and 212 deletions


@@ -5,129 +5,14 @@
from typing import Optional
import torch
from torch import nn, softmax
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.functional import gumbel_softmax, pad
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
Aimv2VisualTokenizerConfig)
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304,
-305] # kept for vocab prefixed tokens
def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through argmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(
y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
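# Illustrative note (not part of the original file), using a toy input: for
# y_soft = torch.tensor([[0.1, 0.7, 0.2]]) and dim=-1, st_argmax returns the
# one-hot [[0., 1., 0.]] in the forward pass, while `y_hard - y_soft.detach()
# + y_soft` lets gradients flow through y_soft (straight-through estimator).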
class Aimv2VisualTokenizer(torch.nn.Module):
def __init__(self,
config: Aimv2VisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs):
super().__init__()
self.config = config
self.backbone = AIMv2Model(
config=config.backbone_config, # noqa
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer")
# reserved tokens for IMAGE_INDICATORS
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
self.head = torch.nn.Sequential(
ReplicatedLinear(
config.backbone_config.hidden_size * config.hidden_stride *
config.hidden_stride,
head_dim,
bias=False,
), torch.nn.LayerNorm(head_dim))
@property
def dtype(self):
return self.backbone.dtype
@property
def device(self):
return self.backbone.device
def tokenize(self, logits):
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
'Invalid `tokenize_function`, expected softmax, gumbel_argmax, '
f'or st_argmax, but got {self.config.tokenize_function}')
return tokens
def encode(self, pixel_values):
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]
# merge number of `hidden_stride * hidden_stride` hidden states together
# to reduce token sequence length
# e.g., for hidden_stride=2, this leads to a token length reduction:
# 1024 -> 256 for aimv2
if self.config.hidden_stride > 1:
# this `d` may be different from the `d` above
n, L, d = features.shape
sqrt_l = int(L**0.5)
assert sqrt_l**2 == L, (
"The token sequence length should be a perfect square.")
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride -
(sqrt_l %
self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride,
self.config.hidden_stride,
sqrt_l // self.config.hidden_stride,
self.config.hidden_stride, d)
# [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.permute(0, 1, 3, 2, 4, 5)
# [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.flatten(3)
# [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1,
self.config.hidden_stride * self.config.hidden_stride * d)
return features
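# Worked shape example (illustrative only, with assumed sizes): for
# hidden_stride=2 and backbone features of shape (n=1, L=1024, d), sqrt_l=32,
# no padding is needed, the grid is regrouped into (1, 16, 2, 16, 2, d),
# permuted to (1, 16, 16, 2, 2, d), and flattened to (1, 256, 4 * d), i.e. the
# 1024 -> 256 reduction noted above, with hidden_stride**2 * d features per token.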
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
features = self.encode(pixel_values)
logits, _ = self.head[0](
features) # we split the Sequential here to avoid an error: ReplicatedLinear returns a (output, bias) tuple
logits = self.head[1](logits)
tokens = self.tokenize(logits)
# tokens' shape is [BatchSize, #Token, VocabSize-5], so pad it with a zero
# tensor of shape [BatchSize, #Token, 5], after which tokens' shape becomes
# [BatchSize, #Token, VocabSize]
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len,
len(IMAGE_INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
return tokens
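# End-to-end shape sketch (illustrative, hypothetical sizes): pixel_values
# [B, C, H, W] -> backbone features [B, L, d] -> merged features
# [B, T, hidden_stride**2 * d] -> head logits [B, T, vocab_size - 5]
# -> tokenize -> zero-pad by len(IMAGE_INDICATOR_IDS) = 5 -> tokens
# [B, T, vocab_size], matching the docstring of forward() above.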
from vllm.transformers_utils.configs.ovis import AIMv2Config
class AIMv2SwiGLUFFN(nn.Module):
@@ -302,14 +187,6 @@ class AIMv2Model(torch.nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.trunk")
@property
def dtype(self):
return self.trunk.blocks[0].attn.qkv.weight.dtype
@property
def device(self):
return self.trunk.blocks[0].attn.qkv.device
def forward(
self,
pixel_values: torch.Tensor,