# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Annotated, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.image_processing_utils import BatchFeature

from vllm.config import ModelConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
)
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.siglip import SiglipMLP
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFeatureSpec,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.config import patch_rope_parameters
from vllm.transformers_utils.configs.isaac import (
    IsaacConfig,
    PixelShuffleSiglip2VisionConfig,
)
from vllm.transformers_utils.processors.isaac import (
    IsaacImageProcessor,
    IsaacProcessor,
    get_image_size_for_max_num_patches,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .vision import is_vit_use_data_parallel


def create_cumulative_seq_lengths(
    seq_sizes: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create cumulative sequence lengths for variable-length attention.

    Args:
        seq_sizes: 1D tensor holding the length of every packed sequence.
        device: Device on which the returned tensors are allocated.

    Returns:
        A pair ``(cu_seqlens, max_seqlen)``. ``cu_seqlens`` is an int32 tensor
        of length ``len(seq_sizes) + 1`` starting at 0 and holding the running
        sum of ``seq_sizes``. ``max_seqlen`` is ``seq_sizes.max()``, or an
        int32 zero tensor when ``seq_sizes`` is empty.
    """
    cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = seq_sizes.cumsum(0)
    max_seqlen = (
        seq_sizes.max()
        if len(seq_sizes) > 0
        else torch.tensor(0, dtype=torch.int32, device=device)
    )
    return cu_seqlens, max_seqlen


class Siglip2VariableSequenceEmbeddings(nn.Module):
    """Patch + positional embeddings for a packed, variable-resolution batch.

    Patches are projected with a linear layer and summed with a learned
    positional grid that is bilinearly resized to every image's patch grid.
    """

    def __init__(self, config: PixelShuffleSiglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        self.patch_embedding = ReplicatedLinear(
            input_size=config.num_channels * self.patch_size * self.patch_size,
            output_size=self.embed_dim,
            return_bias=False,
        )

        self.num_patches = config.num_patches
        # The learned table is treated as a square grid of this side length.
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def positional_embeddings(
        self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Return positional embeddings for every patch in the packed batch.

        Args:
            packed_seq_patches: ``(seq_patches, seq_sizes, spatial_shapes)``;
                only ``spatial_shapes`` is read here, one ``(height, width)``
                row per image.

        Returns:
            Tensor of shape ``(sum(height * width), embed_dim)`` built by
            concatenating the per-image resized grids.
        """
        # Prepare positional embeddings grid: (1, embed_dim, h, w)
        positional_embeddings = (
            self.position_embedding.weight.reshape(
                self.position_embedding_size, self.position_embedding_size, -1
            )
            .permute(2, 0, 1)
            .unsqueeze(0)
        )

        _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches
        pos_embeds_list = []
        mode = "bilinear"
        align_corners = False
        antialias = True
        for spatial_shape in spatial_shapes:
            height, width = int(spatial_shape[0]), int(spatial_shape[1])
            # Guard to ensure height and width are positive for torch.compile
            if height > 0 and width > 0:
                resized_pos_embed = F.interpolate(
                    positional_embeddings,
                    size=(height, width),
                    mode=mode,
                    align_corners=align_corners,
                    antialias=antialias,
                )
                # Reshape from (1, embed_dim, height, width) to
                # (height*width, embed_dim)
                resized_pos_embed = resized_pos_embed.reshape(
                    self.embed_dim, height * width
                ).transpose(0, 1)
            else:
                # Fallback - should never happen in practice
                resized_pos_embed = positional_embeddings.reshape(
                    self.embed_dim,
                    self.position_embedding_size * self.position_embedding_size,
                ).transpose(0, 1)[: height * width]
            pos_embeds_list.append(resized_pos_embed)

        # Concatenate all positional embeddings along the sequence dimension
        pos_embeds = torch.cat(pos_embeds_list, dim=0)
        return pos_embeds

    def forward(
        self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        """Embed packed patches and add their positional embeddings.

        Args:
            packed_seq_patches: ``(seq_patches, seq_sizes, spatial_shapes)``.

        Returns:
            Tensor of shape ``(total_patches, embed_dim)``.
        """
        seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches

        # Match the projection weights so the linear layer can consume the
        # patches regardless of where / in which dtype they were produced.
        target_weight = self.patch_embedding.weight
        seq_patches = seq_patches.to(
            device=target_weight.device, dtype=target_weight.dtype
        )
        patch_embeds = self.patch_embedding(seq_patches)
        pos_embeds = self.positional_embeddings(packed_seq_patches)

        # Flatten patch embeddings to match positional embeddings format
        if patch_embeds.dim() == 3:
            patch_embeds = patch_embeds.view(-1, patch_embeds.size(-1))

        # Add positional embeddings to patch embeddings
        embeddings = patch_embeds + pos_embeds
        return embeddings


def create_pixel_shuffle_index_map(
    seq_sizes: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Build a gather-index map that tells us, for every *output* token after
    pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.

    Args
    ----
    seq_sizes     : (num_images,)  - #patches in each image (row-major order)
    token_grids   : (num_images,2) - (height, width) for every image
    scale_factor  : spatial down-scale factor (≥2). The signature default of
                    1 is rejected with `ValueError`, so callers must pass a
                    value explicitly.
    device        : (optional) overrides `seq_sizes.device`

    Returns
    -------
    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
                 gather_idx[i, j] is the *flat* index into the *original*
                 packed sequence for the j-th sub-patch that forms the
                 i-th output token. Empty `(0, scale_factor**2)` when there
                 are no images.

    Raises
    ------
    ValueError     : if `scale_factor` < 2.
    AssertionError : (outside torch.compile only) if any grid dimension is
                     not divisible by `scale_factor`.
    """
    if device is None:
        device = seq_sizes.device

    r = int(scale_factor)
    if r < 2:
        raise ValueError("`scale_factor` must be ≥ 2")

    # Safety: all spatial dims must be divisible by r
    # Cannot run under torch compile fullgraph mode hence
    if not torch.compiler.is_compiling() and not (
        (token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()
    ):
        raise AssertionError(
            "Every (H,W) in `token_grids` must be divisible by "
            f"scale_factor={r}, got {token_grids.tolist()}"
        )

    gather_chunks: list[torch.Tensor] = []
    tok_offset = 0

    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False):
        # Build the (H, W) grid of flat indices for this image
        grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset
        grid = grid.view(h, w)  # (H, W)

        # -------- identical ordering to your fixed-res routine --------
        # Step 1: split width into blocks of r
        grid = grid.view(h, w // r, r)  # (H, W/r, r)
        # Step 2: now split height into blocks of r
        grid = grid.view(h // r, r, w // r, r)  # (H/r, r, W/r, r)
        # Step 3: final permutation to (H/r, W/r, r, r)
        grid = grid.permute(0, 2, 1, 3).contiguous()  # (H/r, W/r, r, r)
        # Step 4: each (r, r) block forms one output token
        gather_chunks.append(grid.reshape(-1, r * r))  # (H*W / r², r²)

        tok_offset += seq_len

    if not gather_chunks:
        # No images: torch.cat on an empty list raises, so hand back an empty
        # map that still carries the expected trailing dimension.
        return torch.empty((0, r * r), dtype=torch.int64, device=device)

    # Concatenate over all images in the packed batch
    gather_idx = torch.cat(gather_chunks, dim=0)  # (Σ_i HᵢWᵢ/r², r²)
    return gather_idx


def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.

    Args:
        x (`torch.Tensor`):
            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or
            `(1, seq_len, hidden_size)` shapes produced by stacking image
            patches.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)` whose rows give the
            `(height, width)` patch grid sizes corresponding to each image
            segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor specific to pixel shuffle. Values
            greater than one merge `scale_factor**2` neighboring patches into a
            single embedding channel-group. A value of 1 leaves `x` unchanged.

    Returns:
        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input
        convention: `(new_seq_len, hidden_size * scale_factor**2)` when the
        input was 2D, or `(1, new_seq_len, hidden_size * scale_factor**2)` if
        the singleton batch dimension was present, where `new_seq_len` is
        `seq_len / scale_factor**2`.

    Raises:
        AssertionError: If a 3D input has more than one batch item, or
            (outside torch.compile) a grid dimension is not divisible by
            `scale_factor`.
        ValueError: If `scale_factor` is smaller than 1.
    """
    keep_batch_dim = x.dim() == 3
    if keep_batch_dim:
        if x.size(0) != 1:
            raise AssertionError("Packed sequence is expected to have batch_size == 1")
        x_ = x.squeeze(0)  # (seq, embed)
    else:
        x_ = x  # (seq, embed)

    embed_dim = x_.size(-1)
    r = int(scale_factor)

    if r == 1:
        # Nothing to merge: the documented default must be a no-op rather than
        # tripping the `≥ 2` check of the index-map builder.
        return x

    # Calculate seq_sizes from token_grids
    seq_sizes = torch.prod(token_grids, dim=-1)

    # Build index map and gather in one go
    gather_idx = create_pixel_shuffle_index_map(
        seq_sizes=seq_sizes,
        token_grids=token_grids,
        scale_factor=r,
        device=x_.device,
    )  # (new_seq, r²)

    # Gather → (new_seq, r², embed_dim)
    gathered = x_[gather_idx]  # fancy indexing keeps gradient

    # Merge the r² group dimension into channels to finish the shuffle
    out = gathered.reshape(gathered.size(0), embed_dim * r * r)

    # Restore batch dimension if needed
    if keep_batch_dim:
        out = out.unsqueeze(0)
    return out


# ============================================================================
# Configuration
# ============================================================================


class IsaacProcessingInfo(BaseProcessingInfo):
    """Processing metadata (config, processors, size limits) for Isaac."""

    def get_hf_config(self) -> IsaacConfig:
        """Build an `IsaacConfig` from the context's HF config.

        Falls back to a default-constructed `IsaacConfig` when the context
        does not expose `get_hf_config`.
        """
        if hasattr(self.ctx, "get_hf_config"):
            original_config = self.ctx.get_hf_config()
            # Map HF config parameters to our vLLM config parameters
            return IsaacConfig(
                # Vision parameters - map from HF names
                vision_config=getattr(original_config, "vision_config", None),
                vision_patch_size=getattr(original_config, "video_patch_size", 16),
                vision_max_num_patches=getattr(
                    original_config, "vision_max_num_patches", 256
                ),
                vision_min_num_patches=getattr(
                    original_config, "vision_min_num_patches", None
                ),
                pixel_shuffle_scale=getattr(original_config, "pixel_shuffle_scale", 1),
                max_sequence_length=getattr(
                    original_config, "max_sequence_length", 16384
                ),
                # NOTE(review): the fallback is an empty string; confirm this
                # literal is intended and was not lost by tooling (e.g. an
                # angle-bracket token being stripped).
                vision_token=getattr(original_config, "vision_token", ""),
                vision_attn_implementation=getattr(
                    original_config, "vision_attn_implementation", None
                ),
            )
        return IsaacConfig()

    def get_image_processor(self, **kwargs) -> IsaacImageProcessor:
        """Construct an `IsaacImageProcessor` from the given keyword args."""
        return IsaacImageProcessor(**kwargs)

    def get_hf_processor(self, **kwargs) -> IsaacProcessor:
        """Construct an `IsaacProcessor` wired to the tokenizer and image processor."""
        hf_config = self.get_hf_config()
        return IsaacProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=self.get_image_processor(**kwargs),
            image_token=hf_config.vision_token,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size obtained for a very large input image."""
        hf_config = self.get_hf_config()
        # Get target dimensions
        target_height, target_width = get_image_size_for_max_num_patches(
            9999999,
            9999999,
            hf_config.video_patch_size,
            hf_config.vision_max_num_patches,
            min_num_patches=hf_config.vision_min_num_patches,
            pixel_shuffle_scale=hf_config.pixel_shuffle_scale,
        )
        return ImageSize(width=target_width, height=target_height)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported; no per-prompt limit is declared."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the max number of vision tokens one image can produce.

        Computed as `vision_max_num_patches // pixel_shuffle_scale**2`;
        `seq_len` and `mm_counts` are not used.
        """
        hf_config = self.get_hf_config()
        num_vision_tokens = hf_config.vision_max_num_patches // (
            hf_config.pixel_shuffle_scale**2
        )
        return {"image": num_vision_tokens}


class IsaacDummyInputsBuilder(BaseDummyInputsBuilder[IsaacProcessingInfo]):
    """Builds dummy text and image inputs for Isaac."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per dummy image."""
        num_images = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`."""
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


class IsaacImagePixelInputs(TensorSchema):
    """
    Schema for validating Isaac image inputs.

    Dimensions:
        - np: Number of patches
        - d: Patch dimension
        - ni: Number of images

    The schema enforces:
        - pixel_values must be 2D: (num_patches, patch_dim)
        - image_grid_thw must be 2D: (num_images, 3)
          where 3 represents [T, H, W]
    """

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "d"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class IsaacMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor describing Isaac's fields and prompt updates."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processor outputs are split across images.

        `pixel_values` is a flat concatenation whose per-image sizes are the
        products of the `image_grid_thw` rows; `image_grid_thw` has one row
        per image.
        """
        # Configure multimodal fields for Isaac model
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_grid_sizes = image_grid_thw.prod(-1)

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image target with the right number of pad tokens."""
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)

        pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", 2)
        merge_length = pixel_shuffle_scale**2

        def get_replacement_isaac(item_idx: int):
            # One pad token per pixel-shuffled patch group of this image.
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            feature_size = int(grid_thw.prod()) // merge_length
            repl_full = "<|image_pad|>" * feature_size
            return PromptUpdateDetails.select_text(repl_full, "<|image_pad|>")

        return [
            PromptReplacement(
                modality="image",
                # NOTE(review): the target is an empty string; confirm this
                # literal is intended and was not lost by tooling (e.g. an
                # angle-bracket image token being stripped).
                target="",
                replacement=get_replacement_isaac,
            )
        ]


class Siglip2VisionAttention(nn.Module):
    """Self-attention over a packed variable-length vision sequence."""

    def __init__(
        self,
        config: PixelShuffleSiglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        use_data_parallel = is_vit_use_data_parallel()
        # With data-parallel vision encoding the projections are not sharded.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            config.hidden_size, config.num_attention_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            config.num_attention_heads, self.tp_size
        )

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=config.num_attention_heads,
            total_num_kv_heads=config.num_attention_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            scale=self.hidden_size_per_attention_head**-0.5,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused `(seq, batch, 3 * heads * head_dim)` tensor.

        Returns:
            `(q, k, v)`, each viewed as `(seq, batch, heads, head_dim)` using
            this partition's head count.
        """
        seq_len, bs, _ = qkv.shape
        q, k, v = qkv.chunk(3, dim=2)
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run attention on a packed `(1, seq, dim)` input.

        Raises:
            ValueError: If the leading batch dimension is not 1.
        """
        batch_size, _, _ = hidden_states.shape
        if batch_size != 1:
            raise ValueError("packed variable-length attention expects batch_size=1")

        x = rearrange(hidden_states, "b s d -> s b d")
        x, _ = self.qkv_proj(x)
        q, k, v = self.split_qkv(x)
        q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))

        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        output = rearrange(output, "s b d -> b s d")
        return output


class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: PixelShuffleSiglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2VisionAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        """Apply the attention block, then the MLP block."""
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Siglip2EncoderLayer` modules."""

    def __init__(
        self,
        config: PixelShuffleSiglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Siglip2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Feed the embeddings through every layer in order."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
        return hidden_states


class Siglip2VisionTransformer(nn.Module):
    """Embeddings + encoder + final layer norm + optional pixel shuffle."""

    def __init__(
        self,
        config: PixelShuffleSiglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VariableSequenceEmbeddings(config)
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        r"""Encode a packed batch of image patches.

        Args:
            packed_seq_patches: ``(seq_patches, token_grids)`` where
                ``seq_patches`` holds the patches of all images concatenated
                and ``token_grids`` has one ``(height, width)`` row per image.

        Returns:
            The last hidden state without a batch dimension; when
            ``pixel_shuffle_scale_factor > 1`` the sequence has been
            pixel-shuffled.
        """
        seq_patches, token_grids = packed_seq_patches
        seq_sizes = torch.prod(token_grids, dim=-1)

        # Get embeddings from packed sequence
        hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))

        # Add a pseudo batch dimension for the encoder
        hidden_states = hidden_states.unsqueeze(0)

        cu_seqlens, max_seqlen = create_cumulative_seq_lengths(
            seq_sizes, hidden_states.device
        )

        hidden_states = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = self.post_layernorm(hidden_states)

        if self.pixel_shuffle_scale_factor > 1:
            hidden_states = pixel_shuffle_varlen(
                x=hidden_states,
                token_grids=token_grids,
                scale_factor=self.pixel_shuffle_scale_factor,
            )
        # Remove the pseudo batch dimension we added earlier
        hidden_states = hidden_states.squeeze(0)

        # return last_hidden_state
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v projections into `qkv_proj`.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


def _resolve_vision_token_id(model_config: ModelConfig, vision_token: str) -> int:
    """Return the first token id the tokenizer produces for `vision_token`."""
    tokenizer = cached_tokenizer_from_config(model_config)
    assert tokenizer is not None
    return tokenizer.encode(vision_token, add_special_tokens=False)[0]


class IsaacVisionEmbedding(nn.Module):
    """Vision transformer followed by a two-layer SiLU projection.

    The sub-module prefixes "0", "1" and "3" mirror the indices used by the
    checkpoint names handled in `IsaacForConditionalGeneration.hf_to_vllm_mapper`.
    """

    def __init__(
        self,
        vision_cfg: PixelShuffleSiglip2VisionConfig,
        hidden_dim: int,
        output_dim: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.transformer = Siglip2VisionTransformer(
            vision_cfg,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "0"),
        )
        self.linear_fc1 = ColumnParallelLinear(
            hidden_dim,
            4 * hidden_dim,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "1"),
            return_bias=False,
        )
        self.act = nn.SiLU()
        self.linear_fc2 = RowParallelLinear(
            4 * hidden_dim,
            output_dim,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "3"),
            return_bias=False,
        )

    def forward(
        self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Encode the packed patches and project them to `output_dim`."""
        hidden_states = self.transformer(packed_seq_patches)
        hidden_states = self.linear_fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_fc2(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    IsaacMultiModalProcessor,
    info=IsaacProcessingInfo,
    dummy_inputs=IsaacDummyInputsBuilder,
)
class IsaacForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Isaac: a Qwen3 language model paired with `IsaacVisionEmbedding`."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.text_model.lm_head.": "language_model.lm_head.",
            "model.text_model.": "language_model.model.",
            "model.vision_embedding.0": "vision_embedding.transformer",
            "model.vision_embedding.1": "vision_embedding.linear_fc1",
            "model.vision_embedding.2": "vision_embedding.act",
            "model.vision_embedding.3": "vision_embedding.linear_fc2",
            "model.vision_embedding.": "vision_embedding.",
            "model.lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if modality.startswith("image"):
            # NOTE(review): the placeholder is an empty string; confirm this
            # literal is intended and was not lost by tooling (e.g. an
            # angle-bracket image token being stripped).
            return ""

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        super().__init__()
        config: IsaacConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        head_dim = config.head_dim
        calculated_mrope_section = [
            head_dim // 4,  # 2x more for temporal dim
            head_dim // 8,
            head_dim // 8,
        ]

        self.vision_token_id = _resolve_vision_token_id(
            vllm_config.model_config, config.vision_token
        )
        config.image_token_id = self.vision_token_id

        # Rope parameters live on the text config when it is a real config
        # object; otherwise they are patched on the top-level config.
        text_cfg = getattr(config, "text_config", None)
        target_cfg = (
            text_cfg
            if text_cfg is not None and not isinstance(text_cfg, dict)
            else config
        )

        rope_scaling = getattr(target_cfg, "rope_scaling", None)
        if rope_scaling is None and target_cfg is config:
            rope_scaling = getattr(config, "_rope_scaling", None)

        patch_rope_parameters(target_cfg)
        rope_parameters = target_cfg.rope_parameters
        rope_parameters["mrope_section"] = calculated_mrope_section
        if rope_scaling is not None and "mrope_interleaved" in rope_scaling:
            rope_parameters.setdefault(
                "mrope_interleaved", rope_scaling["mrope_interleaved"]
            )
        target_cfg.rope_parameters = rope_parameters

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                architectures=["Qwen3ForCausalLM"],
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        vision_cfg = config.vision_config
        if vision_cfg is None:
            raise ValueError("IsaacConfig should always have vision_config")
        attn_impl = (
            config.vision_attn_implementation
            if config.vision_attn_implementation is not None
            else getattr(config, "_attn_implementation", None)
        )
        if attn_impl is not None:
            vision_cfg._attn_implementation = attn_impl

        # Pixel shuffle folds scale**2 patches into the channel dimension.
        hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_embedding = IsaacVisionEmbedding(
                vision_cfg=vision_cfg,
                hidden_dim=hidden_dim,
                output_dim=config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_embedding"),
            )

    def iter_mm_grid_hw(
        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int]]:
        """Yield `(offset, grid_h, grid_w)` per image, ordered by offset.

        The grid sizes are divided by the pixel-shuffle scale factor.

        Raises:
            ValueError: If a feature's modality is not "image".
        """
        spatial_merge_size = self.config.vision_config.pixel_shuffle_scale_factor
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset
            if mm_feature.modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, h // spatial_merge_size, w // spatial_merge_size
            else:
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-row M-RoPE position ids for a prompt with images.

        Returns:
            `(positions, delta)` where `positions` has shape
            `(3, num_positions)` and `delta` is
            `positions.max() + 1 - len(input_tokens)`.
        """
        llm_pos_ids_list = []
        st = 0
        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
            input_tokens, mm_features
        ):
            # Text between the previous segment and this image.
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            # Only row 0 of the image grid is shifted past the preceding text.
            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            grid_indices[0, :] = grid_indices[0, :] + text_len + st_idx
            llm_pos_ids_list.append(grid_indices)
            st = offset + llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            # Trailing text after the last image (or the whole prompt).
            st_idx = llm_pos_ids_list[-1][0, -1] + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(llm_positions), mrope_position_delta

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> IsaacImagePixelInputs | None:
        """Wrap `pixel_values` / `image_grid_thw` kwargs in the input schema.

        Returns `None` when either of the two is missing.
        """
        pixel_values = kwargs.get("pixel_values")
        image_grid_thw = kwargs.get("image_grid_thw")
        if pixel_values is None or image_grid_thw is None:
            return None

        # TensorSchema will automatically validate shapes on initialization
        return IsaacImagePixelInputs(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )

    def _process_image_input(
        self,
        image_input: IsaacImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Encode image patches and split the result per image.

        Returns an empty tuple when `pixel_values` has no elements.
        """
        pixel_values = image_input["pixel_values"]
        image_grid_thw = image_input["image_grid_thw"]
        if pixel_values.numel() == 0:
            return ()

        device = next(self.language_model.parameters()).device
        dtype = self.vision_embedding.linear_fc1.weight.dtype
        pixel_values = pixel_values.to(device=device, dtype=dtype)
        # Columns 1:3 of [T, H, W] are the per-image (H, W) patch grids.
        spatial_grids = image_grid_thw[:, 1:3].to(device, dtype=torch.int32)

        vision_embeddings = self.vision_embedding((pixel_values, spatial_grids))
        merge_size = self.config.vision_config.pixel_shuffle_scale_factor
        sizes = spatial_grids.prod(-1) // (merge_size * merge_size)
        return tuple(vision_embeddings.split(sizes.tolist()))

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Return per-image embeddings, or an empty tuple without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return ()
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate the forward pass to the wrapped language model."""
        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_embedding.linear_fc2",  # The final linear layer
            tower_model="vision_embedding",
        )