[Models]: Use MMEncoderAttention for MoonViT (#31738)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: h100 <h100@inferact.ai>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: h100 <h100@inferact.ai>
Isotr0py authored on 2026-01-06 16:00:25 +08:00 (committed by GitHub)
parent e9717801bd
commit 7101e0851f
2 changed files with 72 additions and 158 deletions


@@ -325,7 +325,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.hidden_size = config.text_config.hidden_size
self.vision_tower = MoonVitPretrainedModel(
config.vision_config,
-            self.use_data_parallel,
+            multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
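
Note: the caller no longer resolves a use_data_parallel flag itself; it hands the whole multimodal config to the vision tower, and the tower derives the flag from mm_encoder_tp_mode (see the moonvit.py hunks below). A minimal sketch of that derivation, restating the new logic from the diff rather than any additional API:

    # Sketch: how MoonViT decides on data-parallel mode from the config it now receives.
    # `multimodal_config` is vLLM's MultiModalConfig; everything else mirrors the diff below.
    use_data_parallel = (
        multimodal_config.mm_encoder_tp_mode == "data"
        if multimodal_config is not None
        else False
    )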


@@ -51,118 +51,20 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import is_flash_attn_2_available
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.config import MultiModalConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.conv import Conv2dLayer
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
from vllm.model_executor.models.utils import maybe_prefix
-from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_varlen_func
-elif current_platform.is_xpu():
-    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
-else:
-    flash_attn_varlen_func = None
-def multihead_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    q_cu_seqlens: torch.Tensor | None = None,
-    k_cu_seqlens: torch.Tensor | None = None,
-) -> torch.Tensor:
-    """Multi-head attention using flash attention 2.
-    Args:
-        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
-            The first element should be 0 and the last element should be q.shape[0].
-        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
-            The first element should be 0 and the last element should be k.shape[0].
-    Returns:
-        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
-            where dim = num_heads * head_dim
-    """
-    # Unified format legal check
-    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
-    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
-    assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
-        "k_cu_seqlens must sum to k.shape[0]"
-    )
-    assert q.dtype in [
-        torch.bfloat16,
-        torch.float16,
-    ], f"unsupported dtype {q.dtype} for multihead attn"
-    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
-    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
-    attn_out = flash_attn_varlen_func(
-        q,
-        k,
-        v,
-        cu_seqlens_q=q_cu_seqlens,
-        cu_seqlens_k=k_cu_seqlens,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        causal=False,
-    )
-    attn_out = attn_out.flatten(start_dim=-2)
-    return attn_out
-def sdpa_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    q_cu_seqlens: torch.Tensor | None = None,
-    k_cu_seqlens: torch.Tensor | None = None,
-) -> torch.Tensor:
-    """SDPA attention.
-    Args:
-        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        q_cu_seqlens: Optional cumulative sequence lengths of q.
-        k_cu_seqlens: Optional cumulative sequence lengths of k.
-    """
-    seq_length = q.shape[0]
-    attention_mask = torch.zeros(
-        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
-    )
-    for i in range(1, len(q_cu_seqlens)):
-        attention_mask[
-            ...,
-            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
-            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
-        ] = True
-    q = q.transpose(0, 1)
-    k = k.transpose(0, 1)
-    v = v.transpose(0, 1)
-    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-    attn_output = attn_output.transpose(0, 1)
-    attn_output = attn_output.reshape(seq_length, -1)
-    return attn_output
-VL_VISION_ATTENTION_FUNCTIONS = {
-    "flash_attention_2": multihead_attention,
-    "sdpa": sdpa_attention,
-}
def _apply_rope_input_validation(x, freqs_cis):
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
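
Reviewer note: the hand-rolled flash-attn and SDPA helpers above, together with the VL_VISION_ATTENTION_FUNCTIONS dispatch, are deleted; backend selection now happens inside MMEncoderAttention. For orientation, a plain-PyTorch sketch of the semantics those helpers implemented for packed (varlen) vision tokens — illustrative only, not the code path vLLM actually runs:

    import torch
    import torch.nn.functional as F

    def packed_sdpa(q, k, v, cu_seqlens):
        # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens marks the token
        # boundary of each image, e.g. tensor([0, n0, n0 + n1, ...]).
        outs = []
        for i in range(1, len(cu_seqlens)):
            s, e = cu_seqlens[i - 1], cu_seqlens[i]
            # Attend only within one image's tokens; SDPA wants (num_heads, seq, head_dim).
            o = F.scaled_dot_product_attention(
                q[s:e].transpose(0, 1),
                k[s:e].transpose(0, 1),
                v[s:e].transpose(0, 1),
            )
            outs.append(o.transpose(0, 1))
        # (total_tokens, num_heads * head_dim), matching what the deleted helpers returned.
        return torch.cat(outs, dim=0).flatten(start_dim=-2)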
@@ -411,11 +313,19 @@ class MLP2(nn.Module):
super().__init__()
assert len(dims) == 3
self.use_data_parallel = use_data_parallel
-        self.fc0 = ReplicatedLinear(
-            dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0")
+        self.fc0 = ColumnParallelLinear(
+            dims[0],
+            dims[1],
+            bias=bias,
+            prefix=maybe_prefix(prefix, "fc0"),
+            disable_tp=self.use_data_parallel,
)
-        self.fc1 = ReplicatedLinear(
-            dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1")
+        self.fc1 = RowParallelLinear(
+            dims[1],
+            dims[2],
+            bias=bias,
+            prefix=maybe_prefix(prefix, "fc1"),
+            disable_tp=self.use_data_parallel,
)
self.activation = activation
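
Note: the MLP now uses the standard column-parallel fc0 / row-parallel fc1 pairing, so the intermediate dimension is sharded across tensor-parallel ranks; disable_tp=self.use_data_parallel keeps the old replicated behaviour when the encoder runs data-parallel. A toy numerical check of why the pairing is safe around an elementwise activation — made-up sizes, plain PyTorch, no vLLM objects:

    import torch

    d, h, tp = 8, 32, 4
    x = torch.randn(3, d)
    w0 = torch.randn(h, d)   # fc0 weight: output dim h is split column-parallel
    w1 = torch.randn(d, h)   # fc1 weight: input dim h is split row-parallel
    full = torch.relu(x @ w0.t()) @ w1.t()
    shard = h // tp
    partial = [
        torch.relu(x @ w0[r * shard:(r + 1) * shard].t())
        @ w1[:, r * shard:(r + 1) * shard].t()
        for r in range(tp)
    ]
    # Summing the per-rank partial outputs (the row-parallel all-reduce) reproduces the
    # unsharded result, because the activation is elementwise on disjoint column blocks.
    assert torch.allclose(full, sum(partial), atol=1e-4)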
@@ -433,35 +343,55 @@ class MoonVitEncoderLayer(nn.Module):
hidden_dim: int,
mlp_dim: int,
prefix: str = "",
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
*,
attn_implementation: str = "sdpa",
activation=F.gelu,
attn_bias: bool = False,
):
super().__init__()
+        self.use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
-        self.attn_implementation = attn_implementation
-        # use fa2 in vllm by default
-        if is_flash_attn_2_available() or current_platform.is_xpu():
-            self.attn_implementation = "flash_attention_2"
+        self.tp_size = (
+            1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+        self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)
self.norm0 = nn.LayerNorm(hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
-        self.use_data_parallel = use_data_parallel
self.mlp = MLP2(
[hidden_dim, mlp_dim, hidden_dim],
activation,
prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
+            use_data_parallel=self.use_data_parallel,
)
-        self.wqkv = ReplicatedLinear(
-            hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
+        self.wqkv = QKVParallelLinear(
+            hidden_size=hidden_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=attn_bias,
+            prefix=f"{prefix}.wqkv",
+            disable_tp=self.use_data_parallel,
)
-        self.wo = ReplicatedLinear(
-            hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
+        self.wo = RowParallelLinear(
+            hidden_dim,
+            hidden_dim,
+            bias=attn_bias,
+            prefix=f"{prefix}.wo",
+            disable_tp=self.use_data_parallel,
)
+        self.attn = MMEncoderAttention(
+            num_heads=self.num_attention_heads_per_partition,
+            head_size=self.hidden_size_per_attention_head,
+            multimodal_config=multimodal_config,
+            prefix=f"{prefix}.attn",
+        )
def attention_qkvpacked(
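
Note: the attention stack is now sized by the per-partition head count: QKVParallelLinear shards the heads across TP ranks (or not, when disable_tp is set), and both MMEncoderAttention and the reshape in attention_qkvpacked below use the same per-rank count. A small arithmetic sketch with made-up numbers, not MoonViT's real config:

    # Hypothetical numbers, only to show how the per-partition sizes are derived.
    num_heads, head_dim, tp_size = 16, 64, 4
    assert num_heads % tp_size == 0          # what divide(num_heads, tp_size) enforces
    heads_per_rank = num_heads // tp_size    # 4 heads live on this rank
    qkv_out_per_rank = 3 * heads_per_rank * head_dim   # this rank's slice of the fused QKV output
    attn_out_width = heads_per_rank * head_dim         # width fed into the row-parallel wo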
@@ -472,14 +402,15 @@ class MoonVitEncoderLayer(nn.Module):
):
"""
Args:
-            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
+            x (torch.Tensor): (seqlen, hidden_dim)
cu_seqlens (torch.Tensor):
"""
+        seq_length = x.size(0)
xqkv, _ = self.wqkv(x)
qkv_shape = xqkv.size()[:-1] + (
3,
-            self.num_heads,
+            self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# xqkv: (batch_size, seqlen, 3, nheads, headdim)
@@ -488,9 +419,18 @@ class MoonVitEncoderLayer(nn.Module):
xq, xk = apply_rope(xq, xk, rope_freqs_cis)
-        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
-        attn_out = attn_func(
-            xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        attn_out = self.attn(
+            xq.unsqueeze(0),
+            xk.unsqueeze(0),
+            xv.unsqueeze(0),
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
)
+        attn_out = attn_out.reshape(
+            seq_length,
+            self.num_attention_heads_per_partition
+            * self.hidden_size_per_attention_head,
+        )
attn_out, _ = self.wo(attn_out)
return attn_out
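
Note on the new call path: q/k/v stay packed as (total_tokens, heads, head_dim); unsqueeze(0) only adds the leading batch dim of 1 that the new call site passes to MMEncoderAttention, while cu_seqlens / max_seqlen carry the per-image boundaries. A sketch of how those two values relate, with made-up token counts:

    import torch

    # Hypothetical per-image token counts after patching; real values come from the image grids.
    tokens_per_image = torch.tensor([196, 256, 144])
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), tokens_per_image.cumsum(0)])
    # cu_seqlens == tensor([  0, 196, 452, 596])
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()   # tensor(256), kept as a tensor as in the diff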
@@ -528,7 +468,7 @@ class MoonVitEncoder(nn.Module):
num_layers: int,
block_cfg: dict,
prefix: str = "",
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
@@ -538,7 +478,7 @@ class MoonVitEncoder(nn.Module):
self.blocks = nn.ModuleList(
[
MoonVitEncoderLayer(
-                    use_data_parallel=use_data_parallel,
+                    multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
**block_cfg,
)
@@ -599,31 +539,6 @@ def patch_merger(
return outputs
-class MoonVitVLProjector(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        merge_kernel_size: list[int, int],
-        hidden_act: str = "gelu",
-        ln_eps: float = 1e-5,
-        out_dim: int = 4096,
-    ):
-        super().__init__()
-        self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
-        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
-        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.act = ACT2FN[hidden_act]
-        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
-        hidden_states = self.linear_1(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
class MoonVitPretrainedModel(PreTrainedModel):
config_class = MoonViTConfig
model_type = "moonvit"
@@ -634,14 +549,13 @@ class MoonVitPretrainedModel(PreTrainedModel):
def __init__(
self,
config: MoonViTConfig,
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
*inputs,
**kwargs,
):
super().__init__(config, *inputs, **kwargs)
config = deepcopy(config)
-        self.use_data_parallel = use_data_parallel
self.merge_kernel_size = config.merge_kernel_size
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
@@ -662,9 +576,9 @@ class MoonVitPretrainedModel(PreTrainedModel):
"mlp_dim": config.intermediate_size,
"activation": ACT2FN["gelu_pytorch_tanh"],
"attn_bias": True,
"attn_implementation": config._attn_implementation,
},
prefix=f"{prefix}.encoder",
+            multimodal_config=multimodal_config,
)
def forward(
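
Closing note: with MMEncoderAttention owning backend selection, the encoder block config no longer threads config._attn_implementation through, and the tower is built from the multimodal config alone. Restating the new construction path from the kimi_vl.py hunk above (names exactly as they appear in the diff; model_config is the surrounding model's vLLM config):

    self.vision_tower = MoonVitPretrainedModel(
        config.vision_config,
        multimodal_config=model_config.multimodal_config,
        prefix=maybe_prefix(prefix, "vision_tower"),
    )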