Use MMEncoderAttention (=use FlashAttention) instead of torch.sdpa in radio.py (#35653)
@@ -10,7 +10,8 @@
 import math
 from collections.abc import Iterable
-from itertools import repeat
+from dataclasses import dataclass
+from itertools import accumulate, repeat
 from typing import TypeAlias
 
 import torch
@@ -477,28 +478,27 @@ class ViTPatchLinear(nn.Linear):
         self.patch_size = patch_size
 
 
+@dataclass(frozen=True, kw_only=True)
+class MaskMetadata:
+    cu_seqlens: torch.Tensor
+    max_seqlen: torch.Tensor
+
+
 class RadioParallelAttention(InternParallelAttention):
     def forward(
-        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
+        self, x: torch.Tensor, mask_meta: MaskMetadata | None = None
     ) -> torch.Tensor:
-        if attn_mask is None:
-            return super().forward(x)
-
         B, N, _ = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
         if self.qk_normalization:
             q, k = self._apply_qk_norm(q, k)
 
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
-        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
-        out = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, scale=self.scale
-        )
-        out = out.transpose(1, 2).reshape(B, N, -1)
+        cu_seqlens, max_seqlen = None, None
+        if mask_meta is not None:
+            cu_seqlens = mask_meta.cu_seqlens
+            max_seqlen = mask_meta.max_seqlen
+        out = self.attn(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
         out, _ = self.proj(out)
         return out
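For context on what the new call expects: the removed SDPA path and the varlen metadata describe the same block-diagonal attention pattern, where each packed image sequence attends only to itself. The standalone sketch below (an illustration, not vLLM's MMEncoderAttention wrapper; the helper name sdpa_with_cu_seqlens is made up) rebuilds that pattern from cu_seqlens and runs it through torch SDPA, which can serve as a numerical reference when checking a varlen backend.

# Standalone sketch (not vLLM code): reconstructs the block-diagonal mask
# implied by cu_seqlens and applies plain torch SDPA to it.
import torch
import torch.nn.functional as F


def sdpa_with_cu_seqlens(
    q: torch.Tensor,  # (1, num_heads, total_tokens, head_dim)
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,  # (num_images + 1,) int32, e.g. [0, 197, 454]
) -> torch.Tensor:
    total = q.shape[-2]
    # Block-diagonal mask: tokens may only attend within their own image.
    mask = torch.zeros(total, total, dtype=torch.bool, device=q.device)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        mask[start:end, start:end] = True
    # Boolean attn_mask broadcasts over batch and heads; True means "attend".
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


# Two images packed into one sequence of 5 tokens: lengths 2 and 3.
q = k = v = torch.randn(1, 4, 5, 16)
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
out = sdpa_with_cu_seqlens(q, k, v, cu)
print(out.shape)  # torch.Size([1, 4, 5, 16])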
@@ -510,11 +510,11 @@ class RadioVisionEncoderLayer(InternVisionEncoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_mask: torch.Tensor | None = None,
+        mask_meta: MaskMetadata | None = None,
     ):
         hidden_states = (
             hidden_states
-            + self.attn(self.norm1(hidden_states), attn_mask=attn_mask) * self.ls1
+            + self.attn(self.norm1(hidden_states), mask_meta=mask_meta) * self.ls1
         )
 
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2
@@ -529,11 +529,11 @@ class RadioVisionEncoder(InternVisionEncoder):
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-        attn_mask: torch.Tensor | None = None,
+        mask_meta: MaskMetadata | None = None,
     ):
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
-            hidden_states = encoder_layer(hidden_states, attn_mask=attn_mask)
+            hidden_states = encoder_layer(hidden_states, mask_meta=mask_meta)
         return hidden_states
@@ -590,44 +590,36 @@ class RadioInternVisionModel(nn.Module):
     def get_input_embeddings(self):
         return self.embeddings
 
-    def create_inter_image_attention_mask(
+    def inter_image_mask_metadata(
         self, imgs_sizes: list[tuple[int, int]], device: torch.device
-    ) -> torch.Tensor:
+    ) -> MaskMetadata:
         patch_size = self.patch_generator.patch_size
         num_skip = self.patch_generator.num_skip
 
         seq_lens = calc_seq_lens(imgs_sizes, patch_size)
-        patch_counts = [seq_len + num_skip for seq_len in seq_lens]
-        total_patches = sum(patch_counts)
-
-        # Create attention mask - default to False (mask out)
-        mask = torch.zeros(
-            total_patches, total_patches, dtype=torch.bool, device=device
+        adjusted = [s + num_skip for s in seq_lens]
+        cu_seqlens = torch.tensor(
+            list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device
         )
-
-        # Each image's patches can only attend to patches from the same image
-        start_idx = 0
-        for patch_count in patch_counts:
-            end_idx = start_idx + patch_count
-            # Allow attention within this image's patches
-            mask[start_idx:end_idx, start_idx:end_idx] = True
-            start_idx = end_idx
-
-        return mask
+        # Keep max_seqlen on CPU to avoid .item() sync
+        # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
+        max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32)
+        return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
 
     def forward(
         self,
         x: torch.Tensor,
-        imgs_sizes: torch.Tensor | None = None,
+        imgs_sizes: list[tuple[int, int]] | None = None,
     ) -> torch.FloatTensor:
         hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
-        attn_mask = None
-        if imgs_sizes is not None and len(imgs_sizes) > 1:
-            # Dynamic Resolution
-            attn_mask = self.create_inter_image_attention_mask(
-                imgs_sizes, device=x.device
+        mask_meta = None
+        if imgs_sizes is not None:
+            assert len(imgs_sizes) > 0
+            # Dynamic resolution: process each image as an independent sequence.
+            mask_meta = self.inter_image_mask_metadata(
+                imgs_sizes, device=hidden_states.device
            )
-        encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
+        encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)
         return encoder_outputs
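As a concrete check of the metadata built by inter_image_mask_metadata, here is a small worked example. The per-image length formula (H // patch_size) * (W // patch_size) stands in for calc_seq_lens, and patch_size=16 with num_skip=1 (a CLS-style prefix token) are assumed values for illustration only, not taken from the model config.

# Standalone sketch of the cu_seqlens / max_seqlen construction, with made-up numbers.
from itertools import accumulate

import torch

imgs_sizes = [(224, 224), (448, 224)]  # (H, W) per image
patch_size, num_skip = 16, 1

# Per-image patch-sequence lengths, assumed to match calc_seq_lens: (H // p) * (W // p)
seq_lens = [(h // patch_size) * (w // patch_size) for h, w in imgs_sizes]  # [196, 392]
adjusted = [s + num_skip for s in seq_lens]                                # [197, 393]

# Cumulative boundaries of each packed per-image sequence: [0, 197, 590]
cu_seqlens = torch.tensor(list(accumulate(adjusted, initial=0)), dtype=torch.int32)
# Longest single-image sequence, kept as a CPU tensor: 393
max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32)

print(cu_seqlens.tolist(), max_seqlen.item())  # [0, 197, 590] 393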
@@ -670,7 +662,7 @@ class RadioModel(nn.Module):
         pixel_values: torch.Tensor | None = None,
         pixel_embeds: torch.Tensor | None = None,
         *,
-        imgs_sizes: torch.Tensor | None = None,
+        imgs_sizes: list[tuple[int, int]] | None = None,
     ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
         y = self.model(pixel_values, imgs_sizes=imgs_sizes)
         return self._extract_final(y, imgs_sizes=imgs_sizes)