[Models]: Use MMEncoderAttention for MoonViT (#31738)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: h100 <h100@inferact.ai> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: h100 <h100@inferact.ai>
This commit is contained in:
@@ -325,7 +325,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
self.vision_tower = MoonVitPretrainedModel(
|
||||
config.vision_config,
|
||||
self.use_data_parallel,
|
||||
multimodal_config=model_config.multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
|
||||
|
||||
@@ -51,118 +51,20 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
elif current_platform.is_xpu():
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
|
||||
def multihead_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: torch.Tensor | None = None,
|
||||
k_cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Multi-head attention using flash attention 2.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
|
||||
The first element should be 0 and the last element should be q.shape[0].
|
||||
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
|
||||
The first element should be 0 and the last element should be k.shape[0].
|
||||
|
||||
Returns:
|
||||
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
|
||||
where dim = num_heads * head_dim
|
||||
"""
|
||||
# Unified format legal check
|
||||
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
|
||||
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
|
||||
assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
|
||||
"k_cu_seqlens must sum to k.shape[0]"
|
||||
)
|
||||
assert q.dtype in [
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
], f"unsupported dtype {q.dtype} for multihead attn"
|
||||
|
||||
max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
|
||||
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=q_cu_seqlens,
|
||||
cu_seqlens_k=k_cu_seqlens,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
causal=False,
|
||||
)
|
||||
attn_out = attn_out.flatten(start_dim=-2)
|
||||
|
||||
return attn_out
|
||||
|
||||
|
||||
def sdpa_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: torch.Tensor | None = None,
|
||||
k_cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""SDPA attention.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
q_cu_seqlens: Optional cumulative sequence lengths of q.
|
||||
k_cu_seqlens: Optional cumulative sequence lengths of k.
|
||||
"""
|
||||
seq_length = q.shape[0]
|
||||
attention_mask = torch.zeros(
|
||||
[1, seq_length, seq_length], device=q.device, dtype=torch.bool
|
||||
)
|
||||
for i in range(1, len(q_cu_seqlens)):
|
||||
attention_mask[
|
||||
...,
|
||||
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
|
||||
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
|
||||
] = True
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
return attn_output
|
||||
|
||||
|
||||
VL_VISION_ATTENTION_FUNCTIONS = {
|
||||
"flash_attention_2": multihead_attention,
|
||||
"sdpa": sdpa_attention,
|
||||
}
|
||||
|
||||
|
||||
def _apply_rope_input_validation(x, freqs_cis):
|
||||
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
|
||||
@@ -411,11 +313,19 @@ class MLP2(nn.Module):
|
||||
super().__init__()
|
||||
assert len(dims) == 3
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.fc0 = ReplicatedLinear(
|
||||
dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0")
|
||||
self.fc0 = ColumnParallelLinear(
|
||||
dims[0],
|
||||
dims[1],
|
||||
bias=bias,
|
||||
prefix=maybe_prefix(prefix, "fc0"),
|
||||
disable_tp=self.use_data_parallel,
|
||||
)
|
||||
self.fc1 = ReplicatedLinear(
|
||||
dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1")
|
||||
self.fc1 = RowParallelLinear(
|
||||
dims[1],
|
||||
dims[2],
|
||||
bias=bias,
|
||||
prefix=maybe_prefix(prefix, "fc1"),
|
||||
disable_tp=self.use_data_parallel,
|
||||
)
|
||||
self.activation = activation
|
||||
|
||||
@@ -433,35 +343,55 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
hidden_dim: int,
|
||||
mlp_dim: int,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
attn_implementation: str = "sdpa",
|
||||
activation=F.gelu,
|
||||
attn_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_dim = hidden_dim
|
||||
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
|
||||
self.attn_implementation = attn_implementation
|
||||
# use fa2 in vllm by default
|
||||
if is_flash_attn_2_available() or current_platform.is_xpu():
|
||||
self.attn_implementation = "flash_attention_2"
|
||||
self.tp_size = (
|
||||
1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
|
||||
)
|
||||
self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)
|
||||
|
||||
self.norm0 = nn.LayerNorm(hidden_dim)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim)
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.mlp = MLP2(
|
||||
[hidden_dim, mlp_dim, hidden_dim],
|
||||
activation,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel,
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
self.wqkv = ReplicatedLinear(
|
||||
hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
|
||||
self.wqkv = QKVParallelLinear(
|
||||
hidden_size=hidden_dim,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
total_num_heads=num_heads,
|
||||
total_num_kv_heads=num_heads,
|
||||
bias=attn_bias,
|
||||
prefix=f"{prefix}.wqkv",
|
||||
disable_tp=self.use_data_parallel,
|
||||
)
|
||||
self.wo = ReplicatedLinear(
|
||||
hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
|
||||
self.wo = RowParallelLinear(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
bias=attn_bias,
|
||||
prefix=f"{prefix}.wo",
|
||||
disable_tp=self.use_data_parallel,
|
||||
)
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def attention_qkvpacked(
|
||||
@@ -472,14 +402,15 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
|
||||
x (torch.Tensor): (seqlen, hidden_dim)
|
||||
cu_seqlens (torch.Tensor):
|
||||
"""
|
||||
seq_length = x.size(0)
|
||||
xqkv, _ = self.wqkv(x)
|
||||
|
||||
qkv_shape = xqkv.size()[:-1] + (
|
||||
3,
|
||||
self.num_heads,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
# xqkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
@@ -488,9 +419,18 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
|
||||
xq, xk = apply_rope(xq, xk, rope_freqs_cis)
|
||||
|
||||
attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
|
||||
attn_out = attn_func(
|
||||
xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
attn_out = self.attn(
|
||||
xq.unsqueeze(0),
|
||||
xk.unsqueeze(0),
|
||||
xv.unsqueeze(0),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
attn_out = attn_out.reshape(
|
||||
seq_length,
|
||||
self.num_attention_heads_per_partition
|
||||
* self.hidden_size_per_attention_head,
|
||||
)
|
||||
attn_out, _ = self.wo(attn_out)
|
||||
return attn_out
|
||||
@@ -528,7 +468,7 @@ class MoonVitEncoder(nn.Module):
|
||||
num_layers: int,
|
||||
block_cfg: dict,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -538,7 +478,7 @@ class MoonVitEncoder(nn.Module):
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MoonVitEncoderLayer(
|
||||
use_data_parallel=use_data_parallel,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
**block_cfg,
|
||||
)
|
||||
@@ -599,31 +539,6 @@ def patch_merger(
|
||||
return outputs
|
||||
|
||||
|
||||
class MoonVitVLProjector(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
merge_kernel_size: list[int, int],
|
||||
hidden_act: str = "gelu",
|
||||
ln_eps: float = 1e-5,
|
||||
out_dim: int = 4096,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
|
||||
|
||||
self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps)
|
||||
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
self.act = ACT2FN[hidden_act]
|
||||
self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MoonVitPretrainedModel(PreTrainedModel):
|
||||
config_class = MoonViTConfig
|
||||
model_type = "moonvit"
|
||||
@@ -634,14 +549,13 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
||||
def __init__(
|
||||
self,
|
||||
config: MoonViTConfig,
|
||||
use_data_parallel: bool = False,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
*inputs,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
config = deepcopy(config)
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.merge_kernel_size = config.merge_kernel_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.patch_size = config.patch_size
|
||||
@@ -662,9 +576,9 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
||||
"mlp_dim": config.intermediate_size,
|
||||
"activation": ACT2FN["gelu_pytorch_tanh"],
|
||||
"attn_bias": True,
|
||||
"attn_implementation": config._attn_implementation,
|
||||
},
|
||||
prefix=f"{prefix}.encoder",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user