[MM Encoder] Add Triton ViT attention backend (#32183)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):
|
||||
|
||||
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
|
||||
@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
||||
|
||||
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
|
||||
@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
max_seqlen = None
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
|
||||
@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
SiglipEncoderLayer(
|
||||
@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
|
||||
|
||||
@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen3_VisionBlock(
|
||||
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
if self.attn_backend in (
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@@ -108,7 +108,7 @@ def get_vit_attn_backend(
|
||||
multimodal_config: MultiModalConfig | None = (
|
||||
model_config.multimodal_config if model_config is not None else None
|
||||
)
|
||||
except AssertionError:
|
||||
except (AssertionError, AttributeError):
|
||||
multimodal_config = None
|
||||
|
||||
attn_backend_override = (
|
||||
@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
|
||||
multimodal_config: MultiModalConfig | None = (
|
||||
model_config.multimodal_config if model_config is not None else None
|
||||
)
|
||||
except AssertionError:
|
||||
except (AssertionError, AttributeError):
|
||||
multimodal_config = None
|
||||
|
||||
mm_encoder_tp_mode = (
|
||||
|
||||
Reference in New Issue
Block a user