[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)
Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -13,7 +13,8 @@ from transformers import Siglip2VisionConfig
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
@@ -28,8 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .vision import get_vit_attn_backend
|
||||
|
||||
|
||||
class VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
@@ -190,7 +189,7 @@ def apply_rotary_pos_emb(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
if is_flash_attn_backend and not current_platform.is_xpu():
|
||||
if is_flash_attn_backend and current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
|
||||
apply_rotary_emb_func = apply_rotary_emb
|
||||
@@ -208,6 +207,7 @@ class Siglip2Attention(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
@@ -227,20 +227,25 @@ class Siglip2Attention(nn.Module):
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
# TODO(Isotr0py): Enable data parallel after we support
|
||||
# disabling TP on parallel linear layer
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.tp_size = (
|
||||
@@ -249,31 +254,13 @@ class Siglip2Attention(nn.Module):
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
self.use_rope = config.use_rope
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_heads_per_partition,
|
||||
head_size=self.head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
prefix=f"{prefix}.attn",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -298,46 +285,23 @@ class Siglip2Attention(nn.Module):
|
||||
keys.unsqueeze(0),
|
||||
cos,
|
||||
sin,
|
||||
self.is_flash_attn_backend,
|
||||
self.attn.is_flash_attn_backend,
|
||||
)
|
||||
queries = queries.squeeze(0)
|
||||
keys = keys.squeeze(0)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
if self.is_flash_attn_backend:
|
||||
attn_output = self.flash_attn_varlen_func(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
).reshape(seq_length, -1)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
batch_size = cu_seqlens.shape[0] - 1
|
||||
outputs = []
|
||||
cu = cu_seqlens.tolist()
|
||||
for i in range(batch_size):
|
||||
start_idx = cu[i]
|
||||
end_idx = cu[i + 1]
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
attn_output = self.attn(
|
||||
query=queries.unsqueeze(0),
|
||||
key=keys.unsqueeze(0),
|
||||
value=values.unsqueeze(0),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
attn_output = attn_output.reshape(
|
||||
seq_length, self.num_heads_per_partition * self.head_dim
|
||||
)
|
||||
|
||||
# Each sequence is processed independently.
|
||||
q_i = queries[start_idx:end_idx].unsqueeze(0)
|
||||
k_i = keys[start_idx:end_idx].unsqueeze(0)
|
||||
v_i = values[start_idx:end_idx].unsqueeze(0)
|
||||
|
||||
# (1, seq_len, num_heads, head_dim) ->
|
||||
# (1, num_heads, seq_len, head_dim)
|
||||
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
|
||||
|
||||
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
||||
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
|
||||
output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
|
||||
outputs.append(output_i)
|
||||
|
||||
attn_output = torch.cat(outputs, dim=0)
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@@ -347,25 +311,30 @@ class Siglip2MLP(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
# TODO(Isotr0py): Enable data parallel after we support
|
||||
# disabling TP on parallel linear layer
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -380,9 +349,8 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -390,16 +358,15 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
self.self_attn = Siglip2Attention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = Siglip2MLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -444,9 +411,8 @@ class Siglip2Encoder(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -455,9 +421,8 @@ class Siglip2Encoder(nn.Module):
|
||||
Siglip2EncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -630,9 +595,8 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -642,9 +606,8 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
self.encoder = Siglip2Encoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@@ -671,18 +634,16 @@ class Siglip2NavitModel(torch.nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vision_model = Siglip2VisionTransformer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user