[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: Shanshan Shen
Date: 2025-12-15 11:13:32 +08:00 (committed by GitHub)
Parent: 84e23d103d
Commit: 87b4d1557d
24 changed files with 1262 additions and 851 deletions
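This commit extracts multimodal-encoder attention into a reusable `MMEncoderAttention` CustomOp (`vllm.attention.layers.mm_encoder_attention`) and rewires the GLM-4V vision tower shown below to use it, deleting the per-model backend probing and the per-backend forward branches. The sketch below shows the wiring pattern the hunks introduce; the constructor and forward keyword arguments are taken from the diff, while the wrapper module itself is illustrative only and not the actual GLM-4V code.

```python
# Minimal sketch of the new wiring, assuming the constructor/forward keywords
# shown in the hunks below; the wrapper module is hypothetical.
import torch
import torch.nn as nn

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig


class VisionSelfAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        super().__init__()
        # Backend selection (FLASH_ATTN / TORCH_SDPA / ROCM_AITER_FA) now lives
        # inside the CustomOp instead of in each vision model.
        self.attn = MMEncoderAttention(
            num_heads=num_heads,
            head_size=head_size,
            multimodal_config=multimodal_config,
        )

    def forward(
        self,
        q: torch.Tensor,  # [batch, seq, heads, head_dim]
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.attn(
            query=q, key=k, value=v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
        )
```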


@@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
@@ -191,10 +193,15 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int,
bias: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
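A pattern repeated across `Glm4vVisionMLP`, `Glm4vVisionAttention`, and `Glm4vPatchMerger` in this diff: the `use_data_parallel` constructor flag is dropped and each module derives it from the multimodal config instead. Restated for clarity, assuming `multimodal_config` is an optional `MultiModalConfig`:

```python
# Derive the encoder data-parallel mode from the config rather than threading
# a boolean flag through every vision submodule.
use_data_parallel = (
    multimodal_config.mm_encoder_tp_mode == "data"
    if multimodal_config
    else False
)
```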
@@ -248,12 +255,16 @@ class Glm4vVisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
@@ -287,34 +298,12 @@ class Glm4vVisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
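The deleted `__init__` logic above probed the ViT attention backend and rejected unsupported ones; `MMEncoderAttention` presumably performs an equivalent selection internally. Below is a reconstruction of what the deleted lines did, using the helpers they called; the function name and the `get_vit_attn_backend` import path are assumptions, not part of this diff.

```python
# Reconstruction of the backend selection removed from Glm4vVisionAttention;
# the real MMEncoderAttention internals may differ.
import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
# Import path assumed for illustration:
from vllm.model_executor.models.vision import get_vit_attn_backend


def select_vit_backend(head_size: int, attn_backend_override=None):
    attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=torch.get_default_dtype(),
        attn_backend_override=attn_backend_override,
    )
    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        attn_backend_override=attn_backend_override,
    )
    if attn_backend not in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.ROCM_AITER_FA,
    }:
        raise RuntimeError(f"Unsupported ViT attention backend: {attn_backend}")
    is_flash_attn_backend = attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
    return attn_backend, flash_attn_varlen_func, is_flash_attn_backend
```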
@@ -338,14 +327,13 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
@@ -356,43 +344,14 @@ class Glm4vVisionAttention(nn.Module):
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
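The deleted forward branches above show what the single `self.attn(...)` call now replaces: for FLASH_ATTN / ROCM_AITER_FA the tensors were flattened to varlen layout and passed to `flash_attn_varlen_func`, while for TORCH_SDPA each sequence delimited by `cu_seqlens` ran through `F.scaled_dot_product_attention` separately to bound memory use. The sketch below reconstructs that dispatch from the deleted lines, returning `[batch, seq, heads, head_dim]` to match the new output contract; the actual `MMEncoderAttention` implementation may differ, and the function name is illustrative.

```python
# Reconstruction of the per-backend dispatch deleted from
# Glm4vVisionAttention.forward.
import torch
import torch.nn.functional as F
from einops import rearrange


def varlen_vit_attention(
    q: torch.Tensor,  # [batch, seq, heads, head_dim]
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor | None,
    is_flash_attn_backend: bool,
    flash_attn_varlen_func=None,  # supplied when a flash-attn backend is active
) -> torch.Tensor:
    batch_size = q.shape[0]
    if is_flash_attn_backend:
        # Pack (batch, seq) into a single varlen dimension for flash attention.
        q_p, k_p, v_p = (rearrange(x, "b s ... -> (b s) ...") for x in (q, k, v))
        out = flash_attn_varlen_func(
            q_p,
            k_p,
            v_p,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            dropout_p=0.0,
            causal=False,
        )
        return rearrange(out, "(b s) h d -> b s h d", b=batch_size)
    # Torch SDPA fallback: attend per sequence for speed and lower peak VRAM.
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    outputs = []
    for q_i, k_i, v_i in zip(
        torch.split(q, lens, dim=1),
        torch.split(k, lens, dim=1),
        torch.split(v, lens, dim=1),
    ):
        q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i))
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(rearrange(out_i, "b h s d -> b s h d"))
    return torch.cat(outputs, dim=1)
```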
@@ -406,9 +365,8 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -420,17 +378,16 @@ class Glm4vVisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Glm4vVisionMLP(
dim,
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@@ -489,11 +446,16 @@ class Glm4vPatchMerger(nn.Module):
d_model: int,
context_dim: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
bias: bool = False,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = d_model
self.proj = ColumnParallelLinear(
self.hidden_size,
@@ -649,19 +611,19 @@ class Glm4vVisionTransformer(nn.Module):
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
assert multimodal_config is not None, "multimodal_config must be provided"
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
@@ -690,9 +652,8 @@ class Glm4vVisionTransformer(nn.Module):
mlp_hidden_dim=vision_config.out_hidden_size,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -701,9 +662,9 @@ class Glm4vVisionTransformer(nn.Module):
d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
bias=False,
prefix=f"{prefix}.merger",
use_data_parallel=self.use_data_parallel,
)
self.embeddings = Glm4vVisionEmbeddings(vision_config)
@@ -723,7 +684,7 @@ class Glm4vVisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
attn_backend_override=multimodal_config.mm_encoder_attn_backend,
)
@property
@@ -775,13 +736,13 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> int | None:
) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
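`compute_attn_mask_seqlen` now returns the maximum sequence length as a tensor instead of a Python int: dropping the `.item()` call keeps the value on-device and avoids forcing a device-to-host synchronization at this point. A small illustration of the difference, with example values only:

```python
import torch

cu_seqlens = torch.tensor([0, 4, 10, 16])   # cumulative patch counts per image
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]

max_seqlen_new = seq_lens.max()         # tensor; stays on device
max_seqlen_old = seq_lens.max().item()  # int; blocks on a device-to-host copy for GPU tensors
```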
@@ -1465,18 +1426,12 @@ class Glm4vForConditionalGeneration(
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Glm4vVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
if config.model_type == "glm4v":