diff --git a/vllm/model_executor/models/radio.py b/vllm/model_executor/models/radio.py index c6dc05cbd..5fa71d7f2 100644 --- a/vllm/model_executor/models/radio.py +++ b/vllm/model_executor/models/radio.py @@ -10,7 +10,8 @@ import math from collections.abc import Iterable -from itertools import repeat +from dataclasses import dataclass +from itertools import accumulate, repeat from typing import TypeAlias import torch @@ -477,28 +478,27 @@ class ViTPatchLinear(nn.Linear): self.patch_size = patch_size +@dataclass(frozen=True, kw_only=True) +class MaskMetadata: + cu_seqlens: torch.Tensor + max_seqlen: torch.Tensor + + class RadioParallelAttention(InternParallelAttention): def forward( - self, x: torch.Tensor, attn_mask: torch.Tensor | None = None + self, x: torch.Tensor, mask_meta: MaskMetadata | None = None ) -> torch.Tensor: - if attn_mask is None: - return super().forward(x) - - B, N, _ = x.shape qkv, _ = self.qkv(x) q, k, v = qkv.chunk(3, dim=-1) if self.qk_normalization: q, k = self._apply_qk_norm(q, k) - q = q.view(B, N, self.num_heads_per_partition, self.head_dim) - k = k.view(B, N, self.num_heads_per_partition, self.head_dim) - v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - q, k, v = (t.transpose(1, 2) for t in (q, k, v)) - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, scale=self.scale - ) - out = out.transpose(1, 2).reshape(B, N, -1) + cu_seqlens, max_seqlen = None, None + if mask_meta is not None: + cu_seqlens = mask_meta.cu_seqlens + max_seqlen = mask_meta.max_seqlen + out = self.attn(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) out, _ = self.proj(out) return out @@ -510,11 +510,11 @@ class RadioVisionEncoderLayer(InternVisionEncoderLayer): def forward( self, hidden_states: torch.Tensor, - attn_mask: torch.Tensor | None = None, + mask_meta: MaskMetadata | None = None, ): hidden_states = ( hidden_states - + self.attn(self.norm1(hidden_states), attn_mask=attn_mask) * self.ls1 + + self.attn(self.norm1(hidden_states), mask_meta=mask_meta) * self.ls1 ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2 @@ -529,11 +529,11 @@ class RadioVisionEncoder(InternVisionEncoder): def forward( self, inputs_embeds: torch.Tensor, - attn_mask: torch.Tensor | None = None, + mask_meta: MaskMetadata | None = None, ): hidden_states = inputs_embeds for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states, attn_mask=attn_mask) + hidden_states = encoder_layer(hidden_states, mask_meta=mask_meta) return hidden_states @@ -590,44 +590,36 @@ class RadioInternVisionModel(nn.Module): def get_input_embeddings(self): return self.embeddings - def create_inter_image_attention_mask( + def inter_image_mask_metadata( self, imgs_sizes: list[tuple[int, int]], device: torch.device - ) -> torch.Tensor: + ) -> MaskMetadata: patch_size = self.patch_generator.patch_size num_skip = self.patch_generator.num_skip seq_lens = calc_seq_lens(imgs_sizes, patch_size) - patch_counts = [seq_len + num_skip for seq_len in seq_lens] - total_patches = sum(patch_counts) - - # Create attention mask - default to False (mask out) - mask = torch.zeros( - total_patches, total_patches, dtype=torch.bool, device=device + adjusted = [s + num_skip for s in seq_lens] + cu_seqlens = torch.tensor( + list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device ) - - # Each image's patches can only attend to patches from the same image - start_idx = 0 - for patch_count in patch_counts: - end_idx = start_idx + patch_count - # Allow attention within this image's patches - mask[start_idx:end_idx, start_idx:end_idx] = True - start_idx = end_idx - - return mask + # Keep max_seqlen on CPU to avoid .item() sync + # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48 + max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32) + return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) def forward( self, x: torch.Tensor, - imgs_sizes: torch.Tensor | None = None, + imgs_sizes: list[tuple[int, int]] | None = None, ) -> torch.FloatTensor: hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes) - attn_mask = None - if imgs_sizes is not None and len(imgs_sizes) > 1: - # Dynamic Resolution - attn_mask = self.create_inter_image_attention_mask( - imgs_sizes, device=x.device + mask_meta = None + if imgs_sizes is not None: + assert len(imgs_sizes) > 0 + # Dynamic resolution: process each image as an independent sequence. + mask_meta = self.inter_image_mask_metadata( + imgs_sizes, device=hidden_states.device ) - encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask) + encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta) return encoder_outputs @@ -670,7 +662,7 @@ class RadioModel(nn.Module): pixel_values: torch.Tensor | None = None, pixel_embeds: torch.Tensor | None = None, *, - imgs_sizes: torch.Tensor | None = None, + imgs_sizes: list[tuple[int, int]] | None = None, ) -> tuple[torch.FloatTensor, torch.FloatTensor]: y = self.model(pixel_values, imgs_sizes=imgs_sizes) return self._extract_final(y, imgs_sizes=imgs_sizes)