[MM Encoder] Add Triton ViT attention backend (#32183)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2026-02-15 22:32:47 +08:00
Committed by: GitHub
Parent: 19fab44152
Commit: 71cd89264f

14 changed files with 178 additions and 51 deletions


@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
         )
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            raise RuntimeError(
-                f"Qwen3-VL does not support {self.attn_backend} backend now."
-            )
         self.blocks = nn.ModuleList(
             [
                 Qwen3_VisionBlock(
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if (
-            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
+        if self.attn_backend in (
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.TRITON_ATTN,
         ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen