diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 2202a4b34..2141163df 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -418,6 +418,7 @@ th { | `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2` | ✅︎ | ✅︎ | | `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ | | `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | +| `HyperCLOVAXForCausalLM` | HyperCLOVAX-SEED-Think-14B | `naver-hyperclovax/HyperCLOVAX-SEED-Think-14B` | ✅︎ | ✅︎ | | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index ec8949b00..c52448083 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -103,6 +103,10 @@ AITER_MODEL_LIST = [ marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus + pytest.param( + "naver-hyperclovax/HyperCLOVAX-SEED-Think-14B", # hyperclovax + marks=[large_gpu_mark(min_gb=32)], + ), ], ) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 81f9347dd..7f806064f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -320,7 +320,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True ), "HyperCLOVAXForCausalLM": _HfExamplesInfo( - "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B", + "naver-hyperclovax/HyperCLOVAX-SEED-Think-14B", trust_remote_code=True, ), "InternLMForCausalLM": _HfExamplesInfo( diff --git a/vllm/model_executor/models/hyperclovax.py b/vllm/model_executor/models/hyperclovax.py new file mode 100644 index 000000000..3176c4284 --- /dev/null +++ b/vllm/model_executor/models/hyperclovax.py @@ -0,0 +1,551 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2025 NAVER Cloud HyperCLOVA team + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 NAVER Cloud HyperCLOVA team. All rights reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only HyperCLOVAX model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from itertools import islice + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.hyperclovax import HyperCLOVAXConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class HyperCLOVAXMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + disable_tp: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + disable_tp=disable_tp, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=disable_tp, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class HyperCLOVAXAttention(nn.Module): + def __init__( + self, + config: HyperCLOVAXConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position_embeddings: int = 8192, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + cache_config: CacheConfig | None = None, + prefix: str = "", + dual_chunk_attention_config: dict | None = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + is_neox_style=True, + rope_parameters=getattr(config, "rope_parameters", None), + dual_chunk_attention_config=dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class HyperCLOVAXDecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + max_position_embeddings = getattr( + config, + "max_position_embeddings", + 8192, + ) + dual_chunk_attention_config = getattr( + config, + "dual_chunk_attention_config", + None, + ) + attention_bias = getattr(config, "attention_bias", False) + + self.self_attn = HyperCLOVAXAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = HyperCLOVAXMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # post-norm (dual-norm) + self.use_post_norm = config.use_post_norm + if self.use_post_norm: + self.post_norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Unlike models that use a fused add-norm kernel (e.g. Llama), HyperCLOVAX + # applies the residual connection explicitly with a muP scaling factor + # (residual + hidden * residual_multiplier). As a result, each layer's + # hidden_states output already includes the residual addition, so the + # incoming residual is not needed and is reset at the start of each layer. + # The residual parameter is kept for interface consistency with other vllm + # decoder layers. + + # Self Attention + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + # Custom ln + if self.use_post_norm: + hidden_states = self.post_norm1(hidden_states) + + # The residual is added outside the layernorm function to apply muP. + hidden_states = residual + hidden_states * self.residual_multiplier # muP + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + # Custom ln + if self.use_post_norm: + hidden_states = self.post_norm2(hidden_states) + + # The residual is added outside the layernorm function to apply muP. + hidden_states = residual + hidden_states * self.residual_multiplier # muP + + return hidden_states, residual + + +@support_torch_compile +class HyperCLOVAXModel(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = HyperCLOVAXDecoderLayer, + ): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.embed_tokens: VocabParallelEmbedding | PPMissingLayer + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm: RMSNorm | PPMissingLayer + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + assert input_ids is not None + hidden_states = self.embed_input_ids(input_ids) + residual = None + + hidden_states *= self.config.embedding_multiplier # muP + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + assert residual is not None + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + # The residual is added outside the layernorm function to apply muP. + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name or "zero_point" in name: + # Remapping the name of FP8 kv-scale or zero point. + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is None: + continue + name = remapped_name + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore[attr-defined] + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class HyperCLOVAXForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = HyperCLOVAXDecoderLayer, + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, + ) + + self.lm_head: ParallelLMHead | PPMissingLayer + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + if hasattr(config, "logits_scaling"): + logit_scale *= config.logits_scaling # muP + self.logits_processor = LogitsProcessor( + config.vocab_size, + scale=logit_scale, + ) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( # type: ignore[method-assign] + self.model.make_empty_intermediate_tensors + ) + + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = HyperCLOVAXDecoderLayer, + ): + return HyperCLOVAXModel( + vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type, + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_tokens(input_ids) + + def forward( # type: ignore[override] + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + *, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["lm_head."] if self.config.tie_word_embeddings else None, + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index bef18dbd5..51f370bcc 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -133,7 +133,7 @@ _TEXT_GENERATION_MODELS = { "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), "HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"), - "HyperCLOVAXForCausalLM": ("llama", "LlamaForCausalLM"), + "HyperCLOVAXForCausalLM": ("hyperclovax", "HyperCLOVAXForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index a19a5ec0f..1d5aecd80 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -33,6 +33,7 @@ _CLASS_TO_MODULE: dict[str, str] = { "HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl", + "HyperCLOVAXConfig": "vllm.transformers_utils.configs.hyperclovax", "IsaacConfig": "vllm.transformers_utils.configs.isaac", # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the @@ -91,6 +92,7 @@ __all__ = [ "HunYuanVLConfig", "HunYuanVLTextConfig", "HunYuanVLVisionConfig", + "HyperCLOVAXConfig", "IsaacConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/hyperclovax.py b/vllm/transformers_utils/configs/hyperclovax.py new file mode 100644 index 000000000..9fa823743 --- /dev/null +++ b/vllm/transformers_utils/configs/hyperclovax.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2025 NAVER Cloud HyperCLOVA team +# +# Copyright 2025 NAVER Cloud HyperCLOVA team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""HyperCLOVA X model configuration.""" + +from transformers.configuration_utils import PretrainedConfig + + +class HyperCLOVAXConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`HyperCLOVAXModel`]. It is used to instantiate a HyperCLOVAX model + according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the HyperCLOVAX model. Defines the number of + different tokens that can be represented by the `input_ids` + passed when calling [`HyperCLOVAXModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA), if `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each + group key and value head should be constructed by meanpooling all + the original heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the + decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used + with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during + pretraining. Please refer to [this document](https://huggingface. + co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) + to understand more about it. This value is necessary to ensure + exact reproducibility of the pretraining results. Please refer to + [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE + embeddings. NOTE: if you apply new rope type and you expect the + model to work on longer `max_position_embeddings`, we recommend + you to update this value accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', + 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with + 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling + factor to apply to the RoPE embeddings. In most scaling + types, a `factor` of x will enable the model to handle + sequences of length x * original maximum pre-trained + length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The + original max position embeddings used during pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be + applied on the attention computation. If unspecified, it + defaults to value recommended by the implementation, using + the `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for + extrapolation (only) in the linear ramp function. If + unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for + interpolation (only) in the linear ramp function. If + unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be + applied to short contexts (< + `original_max_position_embeddings`). Must be a list of + numbers with the same length as the hidden size divided + by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be + applied to long contexts (< + `original_max_position_embeddings`). Must be a list of + numbers with the same length as the hidden size divided + by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low + frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high + frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output + projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers + in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to + hidden_size // num_heads + embedding_multiplier (`float`, *optional*, defaults to `None`): + Multiplier applied to the embedding weights. If `None`, it is + equivalent to `1.0`. + logits_scaling (`float`, *optional*, defaults to `None`): + Scaling factor for logits. If `None`, it is equivalent to `1.0`. + attention_multiplier (`float`, *optional*, defaults to `None`): + Multiplier applied to the attention weights. If `None`, it is + equivalent to `self.head_dim ** -0.5`. + residual_multiplier (`float`, *optional*, defaults to `None`): + Scaling factor for residual connections. If `None`, it is + equivalent to `1.0`. + use_post_norm (`bool`, *optional*, defaults to `True`): + Determines whether to apply Peri-Layer Normalization. Set to + False to disable this feature. + rope_parameters (`dict`, *optional*): + Dictionary containing the RoPE parameters used by vLLM's + `get_rope`. When provided, takes precedence over `rope_theta` + and `rope_scaling`. If `None`, it is derived from `rope_theta` + and `rope_scaling` automatically. + """ + + model_type = "hyperclovax" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + embedding_multiplier=None, # mup + logits_scaling=None, # mup + attention_multiplier=None, # mup + residual_multiplier=None, # mup + use_post_norm=True, # post-norm(peri-LN) + rope_parameters=None, + auto_map=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = ( + head_dim + if head_dim is not None + else self.hidden_size // self.num_attention_heads + ) + # Derive rope_parameters for vLLM's get_rope() from rope_theta / + # rope_scaling, unless the caller already provided rope_parameters. + if rope_parameters is None: + if rope_scaling is not None: + # Shallow-copy to avoid mutating the caller's dict. + rope_parameters = dict(rope_scaling) + # BC: 'type' field -> 'rope_type', remove stale key. + if "type" in rope_parameters: + rope_parameters.setdefault("rope_type", rope_parameters.pop("type")) + else: + rope_parameters = {"rope_type": "default"} + if "rope_theta" not in rope_parameters: + rope_parameters["rope_theta"] = rope_theta + self.rope_parameters = rope_parameters + + # BC: keep self.rope_scaling consistent for HF serialization. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + # mup + self.embedding_multiplier = ( + embedding_multiplier if embedding_multiplier is not None else 1.0 + ) + self.logits_scaling = logits_scaling if logits_scaling is not None else 1.0 + self.attention_multiplier = ( + attention_multiplier + if attention_multiplier is not None + else self.head_dim**-0.5 + ) + self.residual_multiplier = ( + residual_multiplier if residual_multiplier is not None else 1.0 + ) + + # post-norm (Peri-LN) + self.use_post_norm = use_post_norm + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + auto_map=auto_map, + **kwargs, + )