Use MMEncoderAttention (i.e. FlashAttention) instead of torch.sdpa in radio.py (#35653)

This commit is contained in:
Netanel Haber
2026-03-04 18:43:13 +02:00
committed by GitHub
parent 2f2212e6cc
commit 289fc48ab7

View File

@@ -10,7 +10,8 @@
import math
from collections.abc import Iterable
from itertools import repeat
from dataclasses import dataclass
from itertools import accumulate, repeat
from typing import TypeAlias
import torch
@@ -477,28 +478,27 @@ class ViTPatchLinear(nn.Linear):
self.patch_size = patch_size
@dataclass(frozen=True, kw_only=True)
class MaskMetadata:
cu_seqlens: torch.Tensor
max_seqlen: torch.Tensor
class RadioParallelAttention(InternParallelAttention):
def forward(
self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
self, x: torch.Tensor, mask_meta: MaskMetadata | None = None
) -> torch.Tensor:
if attn_mask is None:
return super().forward(x)
B, N, _ = x.shape
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, scale=self.scale
)
out = out.transpose(1, 2).reshape(B, N, -1)
cu_seqlens, max_seqlen = None, None
if mask_meta is not None:
cu_seqlens = mask_meta.cu_seqlens
max_seqlen = mask_meta.max_seqlen
out = self.attn(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
out, _ = self.proj(out)
return out
@@ -510,11 +510,11 @@ class RadioVisionEncoderLayer(InternVisionEncoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
attn_mask: torch.Tensor | None = None,
mask_meta: MaskMetadata | None = None,
):
hidden_states = (
hidden_states
+ self.attn(self.norm1(hidden_states), attn_mask=attn_mask) * self.ls1
+ self.attn(self.norm1(hidden_states), mask_meta=mask_meta) * self.ls1
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2
@@ -529,11 +529,11 @@ class RadioVisionEncoder(InternVisionEncoder):
def forward(
self,
inputs_embeds: torch.Tensor,
attn_mask: torch.Tensor | None = None,
mask_meta: MaskMetadata | None = None,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states, attn_mask=attn_mask)
hidden_states = encoder_layer(hidden_states, mask_meta=mask_meta)
return hidden_states
@@ -590,44 +590,36 @@ class RadioInternVisionModel(nn.Module):
def get_input_embeddings(self):
return self.embeddings
def create_inter_image_attention_mask(
def inter_image_mask_metadata(
self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
) -> MaskMetadata:
patch_size = self.patch_generator.patch_size
num_skip = self.patch_generator.num_skip
seq_lens = calc_seq_lens(imgs_sizes, patch_size)
patch_counts = [seq_len + num_skip for seq_len in seq_lens]
total_patches = sum(patch_counts)
# Create attention mask - default to False (mask out)
mask = torch.zeros(
total_patches, total_patches, dtype=torch.bool, device=device
adjusted = [s + num_skip for s in seq_lens]
cu_seqlens = torch.tensor(
list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device
)
# Each image's patches can only attend to patches from the same image
start_idx = 0
for patch_count in patch_counts:
end_idx = start_idx + patch_count
# Allow attention within this image's patches
mask[start_idx:end_idx, start_idx:end_idx] = True
start_idx = end_idx
return mask
# Keep max_seqlen on CPU to avoid .item() sync
# See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32)
return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
def forward(
self,
x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None,
imgs_sizes: list[tuple[int, int]] | None = None,
) -> torch.FloatTensor:
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1:
# Dynamic Resolution
attn_mask = self.create_inter_image_attention_mask(
imgs_sizes, device=x.device
mask_meta = None
if imgs_sizes is not None:
assert len(imgs_sizes) > 0
# Dynamic resolution: process each image as an independent sequence.
mask_meta = self.inter_image_mask_metadata(
imgs_sizes, device=hidden_states.device
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)
return encoder_outputs
@@ -670,7 +662,7 @@ class RadioModel(nn.Module):
pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None,
*,
imgs_sizes: torch.Tensor | None = None,
imgs_sizes: list[tuple[int, int]] | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
return self._extract_final(y, imgs_sizes=imgs_sizes)