Signed-off-by: Wangbei25 <wangbei41@huawei.com> Signed-off-by: Wangbei25 <wangbei41@huawei.com> Co-authored-by: Wangbei25 <wangbei41@huawei.com>
290 lines
9.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
||
# adapted from
|
||
# https://github.com/deepseek-ai/DeepSeek-OCR-2/blob/main/DeepSeek-OCR2-master/DeepSeek-OCR2-vllm/deepencoderv2/qwen2_d2e.py
|
||
|
||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
# All rights reserved.
|
||
|
||
# This source code is licensed under the license found in the
|
||
# LICENSE file in the root directory of this source tree.
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import transformers
|
||
|
||
from vllm.model_executor.custom_op import PluggableLayer
|
||
|
||
|
||
# --8<-- [start:qwen2_decoder]
@PluggableLayer.register("qwen2_decoder")
class CustomQwen2Decoder(PluggableLayer):
    """
    Qwen2 decoder repurposed as a visual encoder.

    Mixes non-causal and causal attention within a single sequence.
    ``token_type_ids``: 0 = non-causal (image tokens), 1 = causal (query/text tokens).
    """

    # --8<-- [end:qwen2_decoder]

    def __init__(
        self,
        decoder_layer: int = 24,
        max_position_embeddings: int = 131072,
        hidden_dimension: int = 896,
        num_attention_heads: int = 14,
        num_key_value_heads: int = 2,
        intermediate_size: int = 4864,
        vocab_size: int = 151936,
        attn_implementation: str = "sdpa",
        rms_norm_eps: float = 1e-06,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
    ):
        super().__init__()

        # Resolve the HF classes from the installed transformers package.
        Qwen2Model = transformers.models.qwen2.modeling_qwen2.Qwen2Model
        Qwen2Config = transformers.Qwen2Config

        # Build a Qwen2 config from the constructor arguments.
        config = Qwen2Config(
            hidden_size=hidden_dimension,
            num_hidden_layers=decoder_layer,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            attention_dropout=attention_dropout,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            # Private HF config field: forces the attention backend (e.g. "sdpa").
            _attn_implementation=attn_implementation,
        )

        # Subclassed Qwen2Model that injects the custom mixed-attention mask.
        self.model = self._create_custom_model(Qwen2Model, config)

        # Inputs always arrive as embeddings (see forward), so the token
        # embedding table is dead weight — drop it to save memory.
        del self.model.embed_tokens

    def _create_custom_model(self, Qwen2Model, config):
        """Instantiate a Qwen2Model subclass that builds a custom 4-D attention mask."""

        class CustomQwen2ModelInner(Qwen2Model):
            def forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                token_type_ids=None,  # 0 = non-causal, 1 = causal
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None,
            ):
                # Stash token_type_ids so _update_causal_mask can read them
                # (the base-class mask hook does not take this argument).
                self._current_token_type_ids = token_type_ids
                # NOTE(review): passing a {"full_attention": mask} mapping as
                # attention_mask matches newer HF per-layer mask dispatch —
                # confirm against the pinned transformers version.
                causal_mask_mapping = {
                    "full_attention": self._update_causal_mask(
                        attention_mask,
                        inputs_embeds,
                        cache_position,
                        past_key_values,
                        output_attentions,
                    )
                }
                outputs = super().forward(
                    input_ids=input_ids,
                    attention_mask=causal_mask_mapping,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    cache_position=cache_position,
                )

                return outputs

            def _update_causal_mask(
                self,
                attention_mask,
                input_tensor,
                cache_position,
                past_key_values,
                output_attentions,
            ):
                # Build an additive float mask: 0 = attend, min_dtype = blocked.
                dtype, device = input_tensor.dtype, input_tensor.device
                min_dtype = torch.finfo(dtype).min
                batch_size, sequence_length = (
                    input_tensor.shape[0],
                    input_tensor.shape[1],
                )

                # token_type_ids stashed by forward() above.
                token_type_ids = self._current_token_type_ids

                # Mixed causal / non-causal attention mask.
                causal_mask = self._create_custom_4d_mask(
                    sequence_length=sequence_length,
                    dtype=dtype,
                    device=device,
                    batch_size=batch_size,
                    token_type_ids=token_type_ids,
                )

                # Fold a 2-D padding mask (1 = keep, 0 = pad) into the 4-D mask.
                if attention_mask is not None and attention_mask.dim() == 2:
                    padding_mask = attention_mask[:, None, None, :].to(dtype=dtype)
                    padding_mask = (1.0 - padding_mask) * min_dtype
                    # NOTE(review): where both masks block, min_dtype is added
                    # twice and saturates toward -inf — harmless under softmax,
                    # but worth confirming for fp16 overflow-checking modes.
                    causal_mask = causal_mask + padding_mask

                return causal_mask

            def _create_custom_4d_mask(
                self,
                sequence_length,
                dtype,
                device,
                batch_size,
                token_type_ids,
            ):
                """Build a (batch, 1, seq, seq) additive attention mask.

                Image tokens (type 0) attend bidirectionally to all image
                tokens; text/query tokens (type 1) attend to every image token
                plus causally to themselves and earlier text tokens.
                """
                min_dtype = torch.finfo(dtype).min

                masks = []
                for b in range(batch_size):
                    # Start fully blocked, then open the allowed positions.
                    mask = torch.full(
                        (sequence_length, sequence_length),
                        fill_value=min_dtype,
                        dtype=dtype,
                        device=device,
                    )

                    type_ids = token_type_ids[b]

                    image_positions = (type_ids == 0).nonzero(as_tuple=True)[0]
                    text_positions = (type_ids == 1).nonzero(as_tuple=True)[0]

                    # Non-causal block: image <-> image fully visible.
                    if len(image_positions) > 0:
                        mask[image_positions[:, None], image_positions] = 0.0

                    # Causal block: each text token sees all image tokens and
                    # the text tokens up to and including itself.
                    for i, text_pos in enumerate(text_positions):
                        if len(image_positions) > 0:
                            mask[text_pos, image_positions] = 0.0
                        mask[text_pos, text_positions[: i + 1]] = 0.0

                    masks.append(mask)

                # (batch, seq, seq) -> (batch, 1, seq, seq): broadcast over heads.
                mask = torch.stack(masks, dim=0).unsqueeze(1)
                return mask

        return CustomQwen2ModelInner(config)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """
        Run the wrapped Qwen2 model on pre-computed embeddings.

        Args:
            inputs_embeds: [batch_size, seq_len, hidden_dim]
            token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal
            attention_mask: [batch_size, seq_len], optional padding mask
        """
        return self.model(
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
|
||
|
||
|
||
class Qwen2Decoder2Encoder(nn.Module):
    """
    Use a Qwen2 decoder as a visual feature encoder.

    Appends a bank of learned query embeddings to the flattened visual
    features, runs the mixed-attention CustomQwen2Decoder over the combined
    sequence, and returns only the query positions' hidden states.

    (Docstring fixed: the previous text described a multilingual-BART/Nougat
    decoder, which this class is not.)
    """

    def __init__(
        self,
        decoder_layer: int,
        hidden_dimension: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
    ):
        super().__init__()

        # Underlying Qwen2 decoder with mixed non-causal/causal attention.
        self.model = CustomQwen2Decoder(
            decoder_layer=decoder_layer,
            hidden_dimension=hidden_dimension,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            intermediate_size=intermediate_size,
            attn_implementation="sdpa",
        )
        # Learned query banks, one per supported visual token count
        # (presumably 768-px -> 144 tokens, 1024-px -> 256 tokens — per names).
        self.query_768 = nn.Embedding(144, hidden_dimension)
        self.query_1024 = nn.Embedding(256, hidden_dimension)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: visual feature map; flattened to (batch, n_query, hidden) where
               n_query must be 144 or 256.

        Returns:
            (batch, n_query, hidden) hidden states at the query positions.

        Raises:
            ValueError: if the flattened token count is neither 144 nor 256.
        """
        # (bs, C, H, W) -> (bs, H*W, C)
        x = x.flatten(2).transpose(1, 2)

        bs, n_query, _ = x.shape

        if n_query == 144:
            param_img = self.query_768.weight
        elif n_query == 256:
            param_img = self.query_1024.weight
        else:
            # Fix: previously fell through with `param_img` unbound, raising a
            # confusing NameError. Fail fast with a clear message instead.
            raise ValueError(
                f"Unsupported visual token count {n_query}; expected 144 or 256"
            )

        # (num_queries, hidden) -> (batch_size, num_queries, hidden_size)
        batch_query_imgs = param_img.unsqueeze(0).expand(bs, -1, -1)

        # Image tokens first, then the learned queries.
        x_combined = torch.cat([x, batch_query_imgs], dim=1)

        # 0 = non-causal image tokens, 1 = causal query tokens.
        # Fix: create on x.device to avoid a CPU/GPU device mismatch.
        token_type_ids = torch.cat(
            [
                torch.zeros(bs, n_query, dtype=torch.long, device=x.device),
                torch.ones(bs, n_query, dtype=torch.long, device=x.device),
            ],
            dim=1,
        )

        y = self.model(x_combined, token_type_ids)[0]

        # Keep only the (causal) query positions' outputs.
        y = y[:, n_query:, :]

        return y
|
||
|
||
|
||
def build_qwen2_decoder_as_encoder(
    decoder_layer=24,
    hidden_dimension=896,
    num_attention_heads=14,
    num_key_value_heads=2,
    intermediate_size=4864,
):
    """Factory: assemble a ``Qwen2Decoder2Encoder`` with the given dimensions.

    Defaults mirror the Qwen2-0.5B-scale configuration used elsewhere in
    this module.
    """
    return Qwen2Decoder2Encoder(
        decoder_layer=decoder_layer,
        hidden_dimension=hidden_dimension,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        intermediate_size=intermediate_size,
    )
|