diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 45465d9c4..a96abd891 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -689,6 +689,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | | `IsaacForConditionalGeneration` | Isaac | T + I+ | `PerceptronAI/Isaac-0.1` | ✅︎ | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | +| `InternS1ProForConditionalGeneration` | Intern-S1-Pro | T + IE+ + VE+ | `internlm/Intern-S1-Pro`, etc. | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | | `KananaVForConditionalGeneration` | Kanana-V | T + I+ | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index dd442d9e3..d0122b318 100755 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -842,6 +842,40 @@ def run_interns1(questions: list[str], modality: str) -> ModelRequestData: ) +# Intern-S1-Pro +def run_interns1_pro(questions: list[str], modality: str) -> ModelRequestData: + model_name = "internlm/Intern-S1-Pro" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|vision_start|><|image_pad|><|vision_end|>" + elif modality == "video": + placeholder = "<|vision_start|><|video_pad|><|vision_end|>" + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"{placeholder}\n{question}"}] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # InternVL def run_internvl(questions: list[str], modality: str) -> ModelRequestData: model_name = "OpenGVLab/InternVL3-2B" @@ -2130,6 +2164,7 @@ model_example_map = { "hyperclovax_seed_vision": run_hyperclovax_seed_vision, "idefics3": run_idefics3, "interns1": run_interns1, + "interns1_pro": run_interns1_pro, "internvl_chat": run_internvl, "kanana_v": run_kanana_v, "keye_vl": run_keye_vl, diff --git a/tests/models/registry.py b/tests/models/registry.py index 0e3d0d312..c38637c1c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -755,6 +755,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { "InternS1ForConditionalGeneration": _HfExamplesInfo( "internlm/Intern-S1", trust_remote_code=True ), + "InternS1ProForConditionalGeneration": _HfExamplesInfo( + "internlm/Intern-S1-Pro", + trust_remote_code=True, + min_transformers_version="5.0.0", + is_available_online=False, + ), "InternVLChatModel": _HfExamplesInfo( "OpenGVLab/InternVL2-1B", extras={ diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 127d84555..9ad7c9cda 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -11,6 +11,7 @@ from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding from .dual_chunk_rope import DualChunkRotaryEmbedding from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding +from .fope import FourierRotaryEmbedding from .linear_scaling_rope import LinearScalingRotaryEmbedding from .llama3_rope import Llama3RotaryEmbedding from .llama4_vision_rope import Llama4VisionRotaryEmbedding @@ -102,6 +103,28 @@ def get_rope( mrope_section=rope_parameters["mrope_section"], mrope_interleaved=rope_parameters.get("mrope_interleaved", False), ) + elif "use_fope" in rope_parameters and rope_parameters["use_fope"]: + extra_kwargs = { + k: v + for k, v in rope_parameters.items() + if k + in ( + "num_key_value_heads", + "num_inv_freq", + "fope_sep_head", + "fope_init_factor", + ) + } + extra_kwargs["init_cache"] = False + rotary_emb = FourierRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) else: rotary_emb = RotaryEmbedding( head_size, diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index ffc6f67da..2147e00d2 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -25,6 +25,7 @@ class RotaryEmbeddingBase(CustomOp): base: float, is_neox_style: bool, dtype: torch.dtype, + init_cache: bool = True, ) -> None: super().__init__() self.head_size = head_size @@ -46,11 +47,12 @@ class RotaryEmbeddingBase(CustomOp): if not hasattr(self, "use_flashinfer"): self.use_flashinfer = False - cache = self._compute_cos_sin_cache() - if not self.use_flashinfer: - cache = cache.to(dtype) - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) + if init_cache: + cache = self._compute_cos_sin_cache() + if not self.use_flashinfer: + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) self.is_rocm_triton_rotary_embed_enabled = ( rocm_aiter_ops.is_triton_rotary_embed_enabled() ) @@ -108,9 +110,16 @@ class RotaryEmbedding(RotaryEmbeddingBase): base: float, is_neox_style: bool, dtype: torch.dtype, + init_cache: bool = True, ) -> None: super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + init_cache=init_cache, ) @staticmethod diff --git a/vllm/model_executor/layers/rotary_embedding/fope.py b/vllm/model_executor/layers/rotary_embedding/fope.py new file mode 100644 index 000000000..4c8a7bcbf --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/fope.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .base import RotaryEmbedding +from .common import rotate_neox + + +class FourierRotaryEmbedding(RotaryEmbedding): + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + init_cache: bool, + # extra parameters for FoPE + num_key_value_heads: int, + num_inv_freq: int, + fope_sep_head: bool, + fope_init_factor: float, + ): + # fope related parameters + self.num_key_value_heads = num_key_value_heads + self.num_inv_freq = num_inv_freq + self.fope_sep_head = fope_sep_head + self.fope_init_factor = fope_init_factor + + super().__init__( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + init_cache=init_cache, + ) + + # setup buffers and parameters + self.inv_freq: torch.Tensor + self.register_buffer( + "inv_freq", self._compute_inv_freq(self.base), persistent=False + ) + + self.input_dim = self.inv_freq.shape[-1] + self.output_dim = self.inv_freq.shape[-1] + self.cos_coef = nn.Parameter( + torch.empty(num_key_value_heads, self.input_dim, self.output_dim), + requires_grad=False, + ) + self.sin_coef = nn.Parameter( + torch.empty(num_key_value_heads, self.input_dim, self.output_dim), + requires_grad=False, + ) + self.sin_coef.weight_loader = self.weight_loader + self.cos_coef.weight_loader = self.weight_loader + + self.cos_sin_cache: torch.Tensor + cache = self._compute_cos_sin_cache().to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + # update cache in the first forward, where sin/cos_coef weights are ready + self.update_cache = True + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute the inverse frequency.""" + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + + inv_freq_idx_selected = torch.ones_like(inv_freq, dtype=torch.bool) + if self.num_inv_freq is not None: + inv_freq_idx_selected[self.num_inv_freq :] = False + else: + inv_freq_idx_selected = inv_freq > ( + 2.0 * torch.pi / self.max_position_embeddings + ) + + inv_freq = inv_freq[inv_freq_idx_selected] + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + device = self.inv_freq.device + t = torch.arange(self.max_position_embeddings, dtype=torch.float, device=device) + + freqs = torch.einsum("j,i -> ji", t, self.inv_freq) + if self.fope_sep_head: + pos_cos = freqs.cos().unsqueeze(0).expand(self.num_key_value_heads, -1, -1) + pos_sin = freqs.sin().unsqueeze(0).expand(self.num_key_value_heads, -1, -1) + else: + pos_cos = freqs.cos() + pos_sin = freqs.sin() + + if self.fope_sep_head: + sin = torch.einsum("htD, hDd -> thd", pos_sin, self.sin_coef.float()) + cos = torch.einsum("htD, hDd -> thd", pos_cos, self.cos_coef.float()) + else: + sin = torch.einsum("tD, Dd -> td", pos_sin, self.sin_coef.float()) + cos = torch.einsum("tD, Dd -> td", pos_cos, self.cos_coef.float()) + + sin = F.pad( + input=sin, + pad=(0, self.head_size // 2 - sin.size(-1)), + mode="constant", + value=1, + ) + cos = F.pad( + input=cos, + pad=(0, self.head_size // 2 - cos.size(-1)), + mode="constant", + value=1, + ) + + sin = torch.cat((sin, sin), dim=-1) + cos = torch.cat((cos, cos), dim=-1) + + # cache: (max_position_embeddings, num_kv_heads, kv_size * 2) + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # update cos/sin cache in the first forward + if self.update_cache: + cache = self._compute_cos_sin_cache().to(self.dtype) + self.cos_sin_cache.copy_(cache) + self.update_cache = False + + positions = positions.flatten() + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + # apply rotary embedding + # query: (seq_len, num_heads, head_size) + # key: (seq_len, num_kv_heads, head_size) + query = query.unflatten(-1, (-1, self.head_size)) + assert key is not None, "Key tensor is required for FoPE." + key = key.unflatten(-1, (-1, self.head_size)) + + assert query.dim() == key.dim() == 3, ( + "Expected query key (seq_len, heads, head_dim)" + ) + assert cos.dim() <= 3 and sin.dim() <= 3 + + need_reshape = False + if cos.dim() == 3: + # for fope + need_reshape = True + query_shape = query.shape + key_shape = key.shape + cos = cos.flatten(0, 1) + sin = sin.flatten(0, 1) + seq_len = cos.size(0) + query = query.view(seq_len, -1, query.size(-1)) + key = key.view(seq_len, -1, key.size(-1)) + + # native implementation of apply rope for neox style + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + query = (query * cos) + (rotate_neox(query) * sin) + key = (key * cos) + (rotate_neox(key) * sin) + + if need_reshape: + query = query.view(query_shape) + key = key.view(key_shape) + + return query, key + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + """load fope weights""" + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + num_key_value_heads = loaded_weight.size(0) + + if num_key_value_heads < world_size: + n_replicate = world_size // num_key_value_heads + world_size = num_key_value_heads + rank = rank // n_replicate + + loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank] + param.data.copy_(loaded_weight) diff --git a/vllm/model_executor/models/interns1_pro.py b/vllm/model_executor/models/interns1_pro.py new file mode 100644 index 000000000..60c92cdda --- /dev/null +++ b/vllm/model_executor/models/interns1_pro.py @@ -0,0 +1,633 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only InternS1Pro model compatible with HuggingFace weights.""" + +import functools +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn +from transformers import AutoProcessor, PretrainedConfig + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import ( + get_ep_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.multimodal import MULTIMODAL_REGISTRY + +from .interfaces import MixtureOfExperts +from .qwen3_moe import ( + Qwen3MoeForCausalLM, +) +from .qwen3_vl import ( + Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, + Qwen3VLProcessingInfo, +) +from .qwen3_vl_moe import Qwen3MoeLLMModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + extract_layer_index, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class InternS1ProProcessingInfo(Qwen3VLProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object) -> AutoProcessor: + return AutoProcessor.from_pretrained( + self.ctx.model_config.model, + trust_remote_code=True, + **kwargs, + ) + + +class InternS1ProMoeMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class InternS1ProMoeSparseMoeBlock(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + config = vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + # Load balancing settings. + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + # For custom routing function + self.n_groups = getattr(config, "router_n_groups", -1) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=True, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, + custom_routing_function=self._custom_routing_function, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + prefix=f"{prefix}.gate", + ) + + @staticmethod + @functools.lru_cache + def get_group_offsets(n_groups: int, group_size: int, device: str): + group_offsets = (torch.arange(n_groups, device=device) * group_size).view( + 1, -1, 1 + ) # [1, n_groups, 1] + return group_offsets + + # TODO: zhouxinyu, use vllm routing functions + def _custom_routing_function( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> torch.Tensor: + routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32) + + if self.n_groups > 0: + assert routing_weights.shape[-1] % self.n_groups == 0, ( + f"{routing_weights.shape[-1]} cannot be divided by {self.n_groups}" + ) + per_group_top_k = topk // self.n_groups + group_size = routing_weights.shape[-1] // self.n_groups + group_offsets = self.get_group_offsets( + self.n_groups, group_size, routing_weights.device + ) + routing_weights = routing_weights.unflatten(-1, (self.n_groups, group_size)) + topk_weights, topk_ids = torch.topk( + routing_weights, per_group_top_k, dim=-1 + ) + topk_ids = (topk_ids + group_offsets).flatten(-2, -1) + topk_weights = topk_weights.flatten(-2, -1) + else: + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert hidden_states.dim() <= 2, ( + "InternS1ProMoeSparseMoeBlock only supports 1D or 2D inputs" + ) + is_input_1d = hidden_states.dim() == 1 + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + + # return to 1d if input is 1d + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states + + +class InternS1ProMoeAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_parameters: dict[str, Any], + max_position_embeddings: int = 32768, + head_dim: int | None = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + dual_chunk_attention_config: dict[str, Any] | None = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + rope_parameters["num_key_value_heads"] = self.num_kv_heads + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + rope_parameters=rope_parameters, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } + if dual_chunk_attention_config + else {}, + ) + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb.forward_native(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class InternS1ProMoeDecoderLayer(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_text_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.hidden_size = config.hidden_size + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) + + # update rope related parameters + rope_scaling = config.rope_scaling + fope_keys = {"fope_init_factor", "fope_sep_head", "num_inv_freq"} + use_fope = any(rope_scaling.get(key) is not None for key in fope_keys) + fope_init_factor = rope_scaling.get("fope_init_factor", None) + fope_sep_head = rope_scaling.get("fope_sep_head", None) + num_inv_freq = rope_scaling.get("num_inv_freq", None) + + config.rope_parameters["use_fope"] = use_fope + config.rope_parameters["fope_init_factor"] = fope_init_factor + config.rope_parameters["fope_sep_head"] = fope_sep_head + config.rope_parameters["num_inv_freq"] = num_inv_freq + + assert use_fope, "should use FOPE for InternS1Pro model" + self.self_attn = InternS1ProMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_parameters=config.rope_parameters, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, + ) + + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = InternS1ProMoeSparseMoeBlock( + vllm_config=vllm_config, prefix=f"{prefix}.mlp" + ) + else: + self.mlp = InternS1ProMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class InternS1ProMoeLLMModel(Qwen3MoeLLMModel): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[torch.nn.Module] = InternS1ProMoeDecoderLayer, + ): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type, + ) + + +class InternS1ProMoeLLMForCausalLM(Qwen3MoeForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config.text_config + self.quant_config = vllm_config.quant_config + self.model = InternS1ProMoeLLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + +class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.language_model.model.layers: + if isinstance(layer.mlp, InternS1ProMoeSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.moe_layers = [] + example_moe = None + for layer in self.language_model.model.layers: + if hasattr(layer, "mlp") and isinstance( + layer.mlp, InternS1ProMoeSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No InternS1ProMoe layer found in the language_model.") + + # Set MoE hyperparameters + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=InternS1ProProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class InternS1ProForConditionalGeneration( + Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts +): + is_3d_moe_weight: bool = True + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }, + orig_to_new_suffix={ + # Handle FOPE rotary embeddings + ".rotary_emb.sin_coef": ".layers.0.self_attn.rotary_emb.sin_coef", + ".rotary_emb.cos_coef": ".layers.0.self_attn.rotary_emb.cos_coef", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: PretrainedConfig = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) + + if not multimodal_config.get_limit_per_prompt( + "image" + ) and not multimodal_config.get_limit_per_prompt("video"): + self.visual = None + else: + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + multimodal_config=multimodal_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = InternS1ProMoeLLMForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) + # Whether to include the gate_up_proj mapping is determined by + # the language model. + self.packed_modules_mapping = ( + self.packed_modules_mapping | self.language_model.packed_modules_mapping + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) + self.visual_dim = config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + # Set MoE hyperparameters + self.set_moe_parameters() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """load weights""" + skip_prefixes = ["model.time_series."] + if self.visual is None: + skip_prefixes.append("visual.") + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 2f95f4141..45aa58ab2 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -428,7 +428,13 @@ class Qwen3MoeDecoderLayer(nn.Module): @support_torch_compile class Qwen3MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[torch.nn.Module] = Qwen3MoeDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_text_config @@ -449,7 +455,7 @@ class Qwen3MoeModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix), + lambda prefix: decoder_layer_type(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 977548339..102d84609 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -325,7 +325,11 @@ class Qwen3_VisionTransformer(nn.Module): self.spatial_merge_size = vision_config.spatial_merge_size self.spatial_merge_unit = self.spatial_merge_size**2 self.temporal_patch_size = vision_config.temporal_patch_size - self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.deepstack_visual_indexes = ( + vision_config.deepstack_visual_indexes + if hasattr(vision_config, "deepstack_visual_indexes") + else [] + ) self.num_grid_per_side = int(self.num_position_embeddings**0.5) # NOTE: This is used for creating empty tensor for all_gather for diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index b39a3d297..af8536e3f 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -48,6 +48,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts from .qwen3_moe import ( + Qwen3MoeDecoderLayer, Qwen3MoeForCausalLM, Qwen3MoeModel, Qwen3MoeSparseMoeBlock, @@ -82,8 +83,18 @@ class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo): } ) class Qwen3MoeLLMModel(Qwen3MoeModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[torch.nn.Module] = Qwen3MoeDecoderLayer, + ): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type, + ) if not get_pp_group().is_first_rank: assert self.start_layer >= len( vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ed2a39d24..5eeb32ed9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -357,6 +357,10 @@ _MULTIMODAL_MODELS = { "interns1", "InternS1ForConditionalGeneration", ), + "InternS1ProForConditionalGeneration": ( + "interns1_pro", + "InternS1ProForConditionalGeneration", + ), "Idefics3ForConditionalGeneration": ( "idefics3", "Idefics3ForConditionalGeneration",