diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 45099e992..c2734ce11 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -375,6 +375,7 @@ th { | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | +| `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B`, etc. | ✅︎ | ✅︎ | | `Exaone4ForCausalLM` | EXAONE-4 | `LGAI-EXAONE/EXAONE-4.0-32B`, etc. | ✅︎ | ✅︎ | | `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index fc57e8799..54d18158f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -250,6 +250,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True ), "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), + "ExaoneMoEForCausalLM": _HfExamplesInfo( + "LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.0.0" + ), "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), @@ -1005,6 +1008,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { trust_remote_code=True, speculative_model="baidu/ERNIE-4.5-21B-A3B-PT", ), + "ExaoneMoeMTP": _HfExamplesInfo( + "LGAI-EXAONE/K-EXAONE-236B-A23B", + speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B", + min_transformers_version="5.0.0", + ), "Glm4MoeMTPModel": _HfExamplesInfo( "zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5", diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index e80fa532a..cba00daaa 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -33,6 +33,7 @@ MTPModelTypes = Literal[ "mimo_mtp", "glm4_moe_mtp", "ernie_mtp", + "exaone_moe_mtp", "qwen3_next_mtp", "longcat_flash_mtp", "mtp", @@ -219,6 +220,15 @@ class SpeculativeConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]} ) + + if hf_config.model_type == "exaone_moe": + hf_config.model_type = "exaone_moe_mtp" + if hf_config.model_type == "exaone_moe_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]} + ) + if hf_config.model_type == "longcat_flash": hf_config.model_type = "longcat_flash_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index b4b7a798f..cff82396a 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -72,6 +72,7 @@ class Exaone4GatedMLP(nn.Module): intermediate_size: int, hidden_act: str, quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, bias: bool = False, prefix: str = "", ) -> None: @@ -88,6 +89,7 @@ class Exaone4GatedMLP(nn.Module): output_size=hidden_size, bias=bias, quant_config=quant_config, + reduce_results=reduce_results, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": diff --git a/vllm/model_executor/models/exaone_moe.py
b/vllm/model_executor/models/exaone_moe.py new file mode 100644 index 000000000..abaade8a3 --- /dev/null +++ b/vllm/model_executor/models/exaone_moe.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only K-EXAONE-236B-A23B model compatible with HuggingFace weights.""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.sequence import IntermediateTensors + +from .exaone4 import Exaone4Attention as ExaoneMoeAttention +from .exaone4 import Exaone4GatedMLP as ExaoneMoeGatedMLP +from .interfaces import SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class ExaoneMoe(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.e_score_correction_bias = nn.Parameter( + torch.empty(config.num_experts, dtype=torch.float32) + ) + + # Load balancing settings.
+ vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + eplb_config.num_redundant_experts = ( + eplb_config.num_redundant_experts + if eplb_config.num_redundant_experts is not None + else 0 + ) + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + + if getattr(config, "num_shared_experts", 0) > 0: + intermediate_size = config.moe_intermediate_size * config.num_shared_experts + self.shared_experts = ExaoneMoeGatedMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs(), + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +class ExaoneMoeDecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + mtp_layer: bool = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = config.hidden_size + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + + self.self_attn = ExaoneMoeAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + if config.is_moe_layer[layer_idx] and not mtp_layer: + self.mlp = ExaoneMoe( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + else: + self.mlp = ExaoneMoeGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class ExaoneMoeModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts + ) + + self.config = config + self.quant_config = quant_config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + if 
get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneMoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if name.startswith("mtp."): + continue + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + # Skip loading extra parameters for GPTQ/modelopt models. + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config.get_text_config() + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = ExaoneMoeModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. 
+ skip_prefixes=( + ["lm_head.", "mtp."] if self.config.tie_word_embeddings else ["mtp."] + ), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/exaone_moe_mtp.py b/vllm/model_executor/models/exaone_moe_mtp.py new file mode 100644 index 000000000..b3c71e6ae --- /dev/null +++ b/vllm/model_executor/models/exaone_moe_mtp.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only ExaoneMoe MTP model.""" + +from collections.abc import Iterable + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.exaone_moe import ExaoneMoeDecoderLayer +from vllm.sequence import IntermediateTensors + +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + maybe_prefix, +) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +@support_torch_compile +class ExaoneMoeMultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + config = model_config.hf_config + + self.config = config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.fc = ColumnParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fc", + ) + self.layers = nn.ModuleList( + ExaoneMoeDecoderLayer( + vllm_config.model_config.hf_config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + mtp_layer=True, + ) + for idx in range(self.num_mtp_layers) + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_fc_norm_embedding = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + inputs_embeds = 
self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = spec_step_idx % self.num_mtp_layers + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class ExaoneMoeMTP(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + cache_config = vllm_config.cache_config + assert not cache_config.enable_prefix_caching, ( + "ExaoneMoeMTP currently does not support prefix caching" + ) + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.model = ExaoneMoeMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") + ) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, + positions, + hidden_states, + intermediate_tensors, + inputs_embeds, + spec_step_idx, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + shared_weight_names = ["embed_tokens", "lm_head"] + + def remap_weight_names(weights): + for name, weight in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif not any(key in name for key in shared_weight_names): + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a8aca4e89..0b165232c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -98,6 +98,7 @@ _TEXT_GENERATION_MODELS = { "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), + "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), @@ -457,6 +458,7 @@ _SPECULATIVE_DECODING_MODELS = { "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), + "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"), "LongCatFlashMTPModel": ("longcat_flash_mtp", 
"LongCatFlashMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"),