diff --git a/vllm/model_executor/models/siglip2.py b/vllm/model_executor/models/lfm2_siglip2.py similarity index 78% rename from vllm/model_executor/models/siglip2.py rename to vllm/model_executor/models/lfm2_siglip2.py index 8fbc408ec..439dba5da 100644 --- a/vllm/model_executor/models/siglip2.py +++ b/vllm/model_executor/models/lfm2_siglip2.py @@ -40,99 +40,111 @@ class Siglip2VisionEmbeddings(nn.Module): self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) - @staticmethod - def resize_positional_embeddings( - positional_embeddings: torch.Tensor, + def forward( + self, + pixel_values_packed: torch.FloatTensor, spatial_shapes: torch.LongTensor, - max_length: int, ) -> torch.Tensor: - """ - Resize positional embeddings to image-specific size and pad to a fixed size. + """Embed patchified pixel values in packed (unpadded) form. Args: - positional_embeddings (`torch.Tensor`): - Position embeddings of shape (height, width, embed_dim) - spatial_shapes (`torch.LongTensor`): - Spatial shapes of shape (batch_size, 2) to resize the positional - embeddings to - max_length (`int`): - Maximum length of the positional embeddings to pad resized - positional embeddings to + pixel_values_packed: (1, total_tokens, patch_dim) or + (total_tokens, patch_dim), packed in tile order. + spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile. Returns: - `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + (1, total_tokens, embed_dim) packed embeddings. """ - batch_size = spatial_shapes.shape[0] + assert spatial_shapes.device.type == "cpu", ( + "Expected `spatial_shapes` on CPU to avoid device-to-host sync in " + "variable-length packing." + ) + + if pixel_values_packed.dim() == 3: + assert pixel_values_packed.shape[0] == 1 + pixel_values_flat = pixel_values_packed[0] + else: + pixel_values_flat = pixel_values_packed + + lengths = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(dtype=torch.int64) + lengths_list = lengths.tolist() + total_tokens = int(sum(lengths_list)) + if total_tokens != pixel_values_flat.shape[0]: + raise ValueError( + "Packed pixel_values token count does not match spatial_shapes: " + f"{pixel_values_flat.shape[0]} vs {total_tokens}." + ) + + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values_flat.to(dtype=target_dtype)) + + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + packed_pos_embeds = self.resize_positional_embeddings_packed( + positional_embeddings, + spatial_shapes, + lengths_list=lengths_list, + ) + + embeddings = patch_embeds + packed_pos_embeds + return embeddings.unsqueeze(0) + + @staticmethod + def resize_positional_embeddings_packed( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + lengths_list: list[int], + ) -> torch.Tensor: + """Resize positional embeddings per image and return a packed tensor. + + Args: + positional_embeddings: (height, width, embed_dim) base grid. + spatial_shapes: (batch_size, 2) on CPU, (height, width) per image. + lengths_list: flattened token length per image (height * width). + + Returns: + (total_tokens, embed_dim) packed positional embeddings, concatenated + in the same order as `lengths_list`. + """ + assert spatial_shapes.device.type == "cpu" + embed_dim = positional_embeddings.shape[-1] source_dtype = positional_embeddings.dtype - resulted_positional_embeddings = torch.empty( - (batch_size, max_length, embed_dim), + total_tokens = int(sum(lengths_list)) + packed_pos_embeds = torch.empty( + (total_tokens, embed_dim), device=positional_embeddings.device, dtype=source_dtype, ) - # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation - positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + # (height, width, embed_dim) -> (1, embed_dim, height, width) + pos_4d = positional_embeddings.permute(2, 0, 1).unsqueeze(0) # Upcast to float32 on CPU because antialias is not supported for - # bfloat16/float16 on CPU - if positional_embeddings.device.type == "cpu": - positional_embeddings = positional_embeddings.to(torch.float32) + # bfloat16/float16 on CPU. + if pos_4d.device.type == "cpu": + pos_4d = pos_4d.to(torch.float32) - for i in range(batch_size): - # (1, dim, height, width) -> (1, dim, target_height, target_width) - height, width = spatial_shapes[i] - resized_embeddings = F.interpolate( - positional_embeddings, + offset = 0 + for i, length in enumerate(lengths_list): + if length <= 0: + continue + height, width = spatial_shapes[i].tolist() + resized = F.interpolate( + pos_4d, size=(height, width), mode="bilinear", align_corners=False, antialias=True, ) + resized = resized.reshape(embed_dim, height * width).transpose(0, 1) + resized = resized.to(source_dtype) + packed_pos_embeds[offset : offset + length] = resized + offset += length - # (1, dim, target_height, target_width) -> - # (target_height * target_width, dim) - resized_embeddings = resized_embeddings.reshape( - embed_dim, height * width - ).transpose(0, 1) - - # Cast to original dtype - resized_embeddings = resized_embeddings.to(source_dtype) - - resulted_positional_embeddings[i, : height * width] = resized_embeddings - resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] - - return resulted_positional_embeddings - - def forward( - self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor - ) -> torch.Tensor: - """ - Args: - pixel_values (`torch.FloatTensor`): - Pixel values of shape (batch_size, max_num_patches, - num_channels * patch_size * patch_size) - spatial_shapes (`list[tuple[int, int]]`): - Spatial shapes of shape (batch_size, 2) to resize the positional - embeddings to - """ - - # Apply patch embeddings to already patchified pixel values - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) - - # Get positional resized and padded positional embeddings - positional_embeddings = self.position_embedding.weight.reshape( - self.position_embedding_size, self.position_embedding_size, -1 - ) - resized_positional_embeddings = self.resize_positional_embeddings( - positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1] - ) - - # Add positional embeddings to patch embeddings - embeddings = patch_embeds + resized_positional_embeddings - return embeddings + return packed_pos_embeds class Siglip2Attention(nn.Module): @@ -402,36 +414,23 @@ class Siglip2VisionTransformer(nn.Module): def forward( self, - pixel_values: torch.FloatTensor, + pixel_values_packed: torch.FloatTensor, spatial_shapes: torch.LongTensor, - packed_mask: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: int | torch.Tensor, + max_seqlen: torch.Tensor, ) -> torch.Tensor: r""" spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): Tensor containing the spatial dimensions (height, width) - of the input images. + of the input images. """ - hidden_states = self.embeddings(pixel_values, spatial_shapes) - flat_mask = packed_mask.view(-1) - packed_indices = flat_mask.nonzero(as_tuple=True)[0] - flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = flat_hidden_states.index_select(0, packed_indices).unsqueeze(0) + hidden_states = self.embeddings(pixel_values_packed, spatial_shapes) encoder_outputs = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) - unpacked = encoder_outputs.new_zeros( - packed_mask.numel(), encoder_outputs.shape[-1] - ) - unpacked.index_copy_(0, packed_indices, encoder_outputs.squeeze(0)) - encoder_outputs = unpacked.view( - packed_mask.shape + (encoder_outputs.shape[-1],) - ) - last_hidden_state = self.post_layernorm(encoder_outputs) - return last_hidden_state + return self.post_layernorm(encoder_outputs) class Siglip2Model(torch.nn.Module): @@ -453,16 +452,14 @@ class Siglip2Model(torch.nn.Module): def forward( self, - pixel_values: torch.FloatTensor, + pixel_values_packed: torch.FloatTensor, spatial_shapes: torch.LongTensor, - packed_mask: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: int | torch.Tensor, + max_seqlen: torch.Tensor, ) -> torch.Tensor: return self.vision_model( - pixel_values=pixel_values, + pixel_values_packed=pixel_values_packed, spatial_shapes=spatial_shapes, - packed_mask=packed_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) diff --git a/vllm/model_executor/models/lfm2_vl.py b/vllm/model_executor/models/lfm2_vl.py index 1d0abc948..f70675171 100644 --- a/vllm/model_executor/models/lfm2_vl.py +++ b/vllm/model_executor/models/lfm2_vl.py @@ -50,7 +50,7 @@ from .interfaces import ( SupportsMultiModal, SupportsPP, ) -from .siglip2 import Siglip2Model +from .lfm2_siglip2 import Siglip2Model from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -450,29 +450,78 @@ class Lfm2VLMultiModalProjector(nn.Module): bias=config.projector_bias, ) - def forward(self, image_features: torch.Tensor): - image_features = self.pixel_unshuffle(image_features) - if self.projector_use_layernorm: - image_features = self.layer_norm(image_features) - hidden_states = self.linear_1(image_features) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states + def forward( + self, + vision_features_packed: torch.Tensor, + spatial_shapes: torch.Tensor, + ) -> torch.Tensor: + """Project packed vision features without materializing padded tensors. - def pixel_unshuffle(self, hidden_states: torch.Tensor): - batch_size, width, height, channels = hidden_states.size() - hidden_states = hidden_states.reshape( - batch_size, width, height // self.factor, channels * self.factor + Args: + vision_features_packed: (total_tokens, hidden_size) packed in tile order. + spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile. + + Returns: + projected_packed: (total_projected_tokens, text_hidden_size) + """ + assert spatial_shapes.device.type == "cpu", ( + "Expected `spatial_shapes` on CPU to avoid device-to-host sync in " + "variable-length packing." ) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape( - batch_size, - height // self.factor, - width // self.factor, - channels * self.factor**2, - ) - hidden_states = hidden_states.permute(0, 2, 1, 3) - return hidden_states + factor = self.factor + device = vision_features_packed.device + hidden_size = vision_features_packed.shape[-1] + + spatial_shapes_list: list[list[int]] = spatial_shapes.tolist() + lengths_list = [h * w for h, w in spatial_shapes_list] + + gather_idx_parts: list[torch.Tensor] = [] + offset = 0 + + dh = torch.arange(factor, dtype=torch.int64) + dw = torch.arange(factor, dtype=torch.int64) + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + dh_flat = dh_grid.reshape(-1) + dw_flat = dw_grid.reshape(-1) + + for (height, width), length in zip(spatial_shapes_list, lengths_list): + if length <= 0: + continue + if height % factor != 0 or width % factor != 0: + raise ValueError( + "spatial_shapes must be divisible by downsample_factor: " + f"got ({height}, {width}) with factor={factor}." + ) + height_out = height // factor + width_out = width // factor + + rows_out = torch.arange(height_out, dtype=torch.int64) + cols_out = torch.arange(width_out, dtype=torch.int64) + rr, cc = torch.meshgrid(rows_out, cols_out, indexing="ij") + rr = rr.reshape(-1) + cc = cc.reshape(-1) + + token_idx = (rr[:, None] * factor + dh_flat[None, :]) * width + ( + cc[:, None] * factor + dw_flat[None, :] + ) + gather_idx_parts.append(token_idx.reshape(-1) + offset) + offset += length + + if gather_idx_parts: + gather_idx = torch.cat(gather_idx_parts).to(device=device) + gathered = vision_features_packed.index_select(0, gather_idx) + unshuffled = gathered.reshape(-1, factor * factor * hidden_size) + else: + unshuffled = vision_features_packed.new_empty( + (0, factor * factor * hidden_size) + ) + + if self.projector_use_layernorm: + unshuffled = self.layer_norm(unshuffled) + hidden_states = self.linear_1(unshuffled) + hidden_states = self.act(hidden_states) + projected_packed = self.linear_2(hidden_states) + return projected_packed @MULTIMODAL_REGISTRY.register_processor( @@ -598,61 +647,90 @@ class Lfm2VLForConditionalGeneration( pixel_values: torch.FloatTensor, spatial_shapes: torch.Tensor, ) -> torch.Tensor: + assert spatial_shapes.device.type == "cpu", ( + "Expected `spatial_shapes` on CPU to avoid device-to-host sync in " + "variable-length packing." + ) + pixel_values = pixel_values.to( dtype=self.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype ) # fp16 compatibility # LFM2-VL's HF processor pads patch sequences with trailing zeros. - # Derive the valid-patch mask from spatial_shapes instead of carrying - # pixel_attention_mask through the vLLM multimodal pipeline. - max_seq_len = pixel_values.shape[1] + # Pack patch tokens upfront so the vision tower runs entirely unpadded. + spatial_shapes_list: list[list[int]] = spatial_shapes.tolist() + lengths_list = [h * w for h, w in spatial_shapes_list] + total_tokens = int(sum(lengths_list)) lengths_cpu = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to( dtype=torch.int32 ) max_seqlen = ( - lengths_cpu.max().reshape(1).to(device=pixel_values.device) + lengths_cpu.max().reshape(1) if lengths_cpu.numel() - else torch.tensor([0], dtype=torch.int32, device=pixel_values.device) + else torch.tensor([0], dtype=torch.int32) ) - lengths = lengths_cpu.to(device=pixel_values.device) - packed_mask = ( - torch.arange(max_seq_len, device=pixel_values.device)[None, :] - < lengths[:, None] + + if total_tokens == 0: + return [] + + packed_pixel_values = pixel_values.new_empty( + (total_tokens, pixel_values.shape[-1]) + ) + offset = 0 + for i, length in enumerate(lengths_list): + if length <= 0: + continue + packed_pixel_values[offset : offset + length].copy_( + pixel_values[i, :length] + ) + offset += length + packed_pixel_values = packed_pixel_values.unsqueeze(0) + + lengths = torch.tensor( + lengths_list, dtype=torch.int32, device=pixel_values.device ) cu_seqlens = torch.zeros( lengths.shape[0] + 1, dtype=torch.int32, - device=lengths.device, + device=pixel_values.device, ) cu_seqlens[1:] = torch.cumsum(lengths, dim=0) with set_forward_context(None, self.vllm_config): vision_outputs = self.vision_tower( - pixel_values=pixel_values, + pixel_values_packed=packed_pixel_values, spatial_shapes=spatial_shapes, - packed_mask=packed_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) - image_outputs = getattr(vision_outputs, "last_hidden_state", vision_outputs) + image_outputs_packed = getattr( + vision_outputs, "last_hidden_state", vision_outputs + ) + vision_features_packed = image_outputs_packed[0] - image_features = [] + factor = self.multi_modal_projector.factor + projected_lengths_list: list[int] = [] + for (height, width), length in zip(spatial_shapes_list, lengths_list): + if length <= 0: + projected_lengths_list.append(0) + continue + if height % factor != 0 or width % factor != 0: + raise ValueError( + "spatial_shapes must be divisible by downsample_factor: " + f"got ({height}, {width}) with factor={factor}." + ) + projected_lengths_list.append((height // factor) * (width // factor)) - # spatial_shapes is on CPU (keep_on_cpu=True), so .tolist() is instant - spatial_shapes_list = spatial_shapes.tolist() - for img_idx, (feature_org_h, feature_org_w) in enumerate(spatial_shapes_list): - feature_len = feature_org_h * feature_org_w - feature = image_outputs[img_idx, :feature_len] + projected_packed = self.multi_modal_projector( + vision_features_packed=vision_features_packed, + spatial_shapes=spatial_shapes, + ) - # reshape to original height and width - feature = feature.reshape(1, feature_org_h, feature_org_w, -1) - - # project the image representation - img_embedding = self.multi_modal_projector(feature) - - # flatten here to handle variable length in naflex - img_embedding = img_embedding.reshape(-1, img_embedding.size(-1)) - image_features.append(img_embedding) + image_features: list[torch.Tensor] = [] + offset = 0 + for out_len in projected_lengths_list: + image_features.append(projected_packed[offset : offset + out_len]) + offset += out_len return image_features diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index cc7e7844d..be3825e19 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] m = common_attn_metadata query_start_loc = m.query_start_loc + query_start_loc_cpu = m.query_start_loc_cpu context_lens_tensor = m.compute_num_computed_tokens() nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + spec_sequence_masks_cpu: torch.Tensor | None = None if ( not self.use_spec_decode or num_decode_draft_tokens_cpu is None @@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] spec_sequence_masks = None num_spec_decodes = 0 else: - spec_sequence_masks = num_decode_draft_tokens_cpu >= 0 - num_spec_decodes = spec_sequence_masks.sum().item() + spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0 + num_spec_decodes = spec_sequence_masks_cpu.sum().item() if num_spec_decodes == 0: spec_sequence_masks = None + spec_sequence_masks_cpu = None else: - spec_sequence_masks = spec_sequence_masks.to( + spec_sequence_masks = spec_sequence_masks_cpu.to( query_start_loc.device, non_blocking=True ) @@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] non_spec_state_indices_tensor = m.block_table_tensor[:, 0] spec_query_start_loc = None non_spec_query_start_loc = query_start_loc + non_spec_query_start_loc_cpu = query_start_loc_cpu num_accepted_tokens = None else: query_lens = query_start_loc[1:] - query_start_loc[:-1] + assert spec_sequence_masks_cpu is not None + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] non_spec_query_lens = query_lens[~spec_sequence_masks] num_decodes = (non_spec_query_lens == 1).sum().item() @@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] non_spec_state_indices_tensor = None spec_query_start_loc = query_start_loc non_spec_query_start_loc = None + non_spec_query_start_loc_cpu = None else: spec_token_masks = torch.repeat_interleave( spec_sequence_masks, query_lens @@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] dim=0, out=non_spec_query_start_loc[1:], ) + non_spec_query_start_loc_cpu = torch.zeros( + query_lens_cpu.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + ) + torch.cumsum( + query_lens_cpu[~spec_sequence_masks_cpu], + dim=0, + out=non_spec_query_start_loc_cpu[1:], + ) assert num_accepted_tokens is not None num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] @@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] has_initial_state = context_lens_tensor > 0 if spec_sequence_masks is not None: has_initial_state = has_initial_state[~spec_sequence_masks] + assert non_spec_query_start_loc_cpu is not None nums_dict, batch_ptr, token_chunk_offset_ptr = ( - compute_causal_conv1d_metadata(non_spec_query_start_loc) + compute_causal_conv1d_metadata( + non_spec_query_start_loc_cpu, + device=query_start_loc.device, + ) ) else: has_initial_state = None diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 0c55877a5..9bf87d1b2 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -219,21 +219,24 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): if num_prefills > 0: if num_computed_tokens is None: num_computed_tokens = common_attn_metadata.compute_num_computed_tokens() - num_computed_tokens_cpu = num_computed_tokens.cpu() + query_start_loc_p_cpu = ( + common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] + - num_decode_tokens + ) query_start_loc_p = ( common_attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decode_tokens ) - has_initial_states_cpu = ( - num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0 - ) - has_initial_states_p = has_initial_states_cpu.to( - common_attn_metadata.query_start_loc.device + has_initial_states_p = ( + num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0 ) nums_dict, batch_ptr, token_chunk_offset_ptr = ( - compute_causal_conv1d_metadata(query_start_loc_p) + compute_causal_conv1d_metadata( + query_start_loc_p_cpu, + device=common_attn_metadata.query_start_loc.device, + ) ) if self.vllm_config.cache_config.enable_prefix_caching: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 82321c000..74392032e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -732,13 +732,17 @@ def create_fast_prefill_custom_backend( return attn_backend -def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): - # Needed for causal_conv1d - seqlens = query_start_loc_p.diff().to("cpu") +def compute_causal_conv1d_metadata( + query_start_loc_p_cpu: torch.Tensor, + *, + device: torch.device, +): + # Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync. + assert query_start_loc_p_cpu.device.type == "cpu" + seqlens = query_start_loc_p_cpu.diff() nums_dict = {} # type: ignore batch_ptr = None token_chunk_offset_ptr = None - device = query_start_loc_p.device for BLOCK_M in [8]: # cover all BLOCK_M values nums = -(-seqlens // BLOCK_M) nums_dict[BLOCK_M] = {}