diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 98d2a08d9..967f3cfb6 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -448,6 +448,7 @@ th { | `OlmoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | | `Olmo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | | `Olmo3ForCausalLM` | OLMo3 | `allenai/Olmo-3-7B-Instruct`, `allenai/Olmo-3-32B-Think`, etc. | ✅︎ | ✅︎ | +| `OlmoHybridForCausalLM` | OLMo Hybrid | `allenai/Olmo-Hybrid-7B` | ✅︎ | ✅︎ | | `OlmoeForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 3c9bb77e7..4a105dedd 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -420,6 +420,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), "Olmo3ForCausalLM": _HfExamplesInfo("allenai/Olmo-3-7B-Instruct"), + "OlmoHybridForCausalLM": _HfExamplesInfo("allenai/Olmo-Hybrid-7B"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo( "facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"} diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c46460959..59af0109b 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -666,6 +666,7 @@ class CompilationConfig: "vllm::linear_attention", "vllm::plamo2_mamba_mixer", "vllm::gdn_attention_core", + "vllm::olmo_hybrid_gdn_full_forward", "vllm::kda_attention", "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py index 4d7dbb510..2eb137a24 100644 --- a/vllm/model_executor/layers/fla/ops/l2norm.py +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -76,16 +76,20 @@ def l2norm_fwd_kernel( @triton.jit -def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): +def l2norm_fwd_kernel2( + X, Y, eps, M, N: tl.constexpr, BD: tl.constexpr, MBLOCK: tl.constexpr +): xoffset = tl.program_id(0) * MBLOCK row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] xmask = row_idx < M - rindex = tl.arange(0, N)[None, :] - xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) - square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + rindex = tl.arange(0, BD)[None, :] + cmask = rindex < N + mask = xmask & cmask + xs = tl.load(X + (rindex + N * row_idx), mask, other=0.0).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, BD]) square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] rsqrt = tl.rsqrt(square_sum + eps) - tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, mask) def l2norm_fwd( @@ -116,6 +120,7 @@ def l2norm_fwd( eps, T, D, + BD, MBLOCK, ) else: diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py index 3abfbff9e..8b9e27573 100644 --- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -250,57 +250,55 @@ def layer_norm_fwd( return out, mean, rstd -class 
LayerNormFn(torch.autograd.Function): - @input_guard - @staticmethod - def forward( - ctx, +def _layer_norm_fn_impl( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + activation: str = "swish", +): + """Triton layer/RMS norm with optional gating. + + If z is not None, computes norm(x) * silu(z) when norm_before_gate, + else norm(x * silu(z)). + + This calls the triton kernel directly. The original code wrapped this + in a torch.autograd.Function (LayerNormFn) to save tensors for a + backward pass, but vLLM is inference-only so there is no backward pass. + The autograd wrapper also prevented torch.compile/dynamo from tracing + through the function due to its @staticmethod forward. + """ + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, _, _ = layer_norm_fwd( x, weight, bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False, - activation: str = "swish", - ): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if z is not None: - assert z.shape == x_shape_og - z = z.reshape(-1, z.shape[-1]) - if z.stride(-1) != 1: - z = z.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - y, mean, rstd = layer_norm_fwd( - x, - weight, - bias, - eps, - z=z, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=is_rms_norm, - activation=activation, - ) - ctx.save_for_backward(x, weight, bias, mean, rstd, z) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.group_size = group_size - ctx.norm_before_gate = norm_before_gate - ctx.is_rms_norm = is_rms_norm - ctx.activation = activation - return y.reshape(x_shape_og) + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + activation=activation, + ) + return y.reshape(x_shape_og) +@input_guard def layernorm_fn( x, weight, @@ -312,11 +310,12 @@ def layernorm_fn( is_rms_norm=False, activation: str = "swish", ): - return LayerNormFn.apply( + return _layer_norm_fn_impl( x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation ) +@input_guard def rmsnorm_fn( x, weight, @@ -327,7 +326,7 @@ def rmsnorm_fn( norm_before_gate=True, activation: str = "swish", ): - return LayerNormFn.apply( + return _layer_norm_fn_impl( x, weight, bias, z, eps, group_size, norm_before_gate, True, activation ) diff --git a/vllm/model_executor/models/olmo_hybrid.py b/vllm/model_executor/models/olmo_hybrid.py new file mode 100644 index 000000000..a94f8c875 --- /dev/null +++ b/vllm/model_executor/models/olmo_hybrid.py @@ -0,0 +1,1172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from: +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +# Copyright 2026 The vLLM team. +# +# This code combines OLMo2/OLMo3 attention with Gated DeltaNet linear attention +# for the OLMo Hybrid architecture. 
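+#
+# Layers are interleaved according to `config.layer_types`: "full_attention"
+# layers use OLMo2-style softmax attention with QK-norm, while
+# "linear_attention" layers use a Gated DeltaNet mixer (a short causal conv1d
+# followed by a gated delta-rule recurrence) whose state is kept in the
+# mamba-style state cache.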
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo Hybrid model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from functools import partial +from itertools import islice + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + divide, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.distributed.utils import split_tensor_along_last_dim +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fla.ops import ( + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, +) +from vllm.model_executor.layers.layernorm import RMSNorm, RMSNormGated +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton +from vllm.triton_utils.allocation import set_triton_allocator +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +def _make_fused_conv1d_weight_loader(dims, tp_size, tp_rank): + """Weight loader for loading separate HF conv weights into a fused conv1d. + + dims: list of original (un-sharded) dims per section, + e.g. 
[key_dim, key_dim, value_dim] + """ + sharded_dims = [d // tp_size for d in dims] + + def weight_loader(param, loaded_weight, loaded_shard_id=None): + if loaded_weight.dim() == 2: + loaded_weight = loaded_weight.unsqueeze(1) + dim = dims[loaded_shard_id] + shard_size = dim // tp_size + tp_start = tp_rank * shard_size + sharded_weight = loaded_weight[tp_start : tp_start + shard_size] + offset = sum(sharded_dims[:loaded_shard_id]) + param.data[offset : offset + shard_size].copy_(sharded_weight) + + return weight_loader + + +class OlmoHybridGatedDeltaNet(nn.Module, MambaBase): + """ + Gated DeltaNet linear attention layer for OLMo Hybrid. + + This implements the linear attention mechanism that replaces sliding window + attention in the hybrid architecture. + """ + + @property + def mamba_type(self) -> str: + return "gdn_attention" + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) + + def __init__( + self, + config, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + assert getattr(config, "linear_use_gate", True), ( + "OlmoHybridGatedDeltaNet requires linear_use_gate=True" + ) + self.allow_neg_eigval = getattr(config, "linear_allow_neg_eigval", False) + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) + + # Fused QKVG projection: 1 matmul instead of 4 + self.in_proj_qkvg = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.key_dim, self.key_dim, self.value_dim, self.value_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkvg", + ) + + # Separate B and A projections to preserve numerical precision. + # Fusing these into one matmul changes FP accumulation order for the + # gating scalars, which compounds through the GDN recurrent state. 
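+        # b -> per-head mixing coefficient beta = sigmoid(b) (scaled to [0, 2]
+        # when allow_neg_eigval is set); a -> per-head decay
+        # g = -exp(A_log) * softplus(a + dt_bias). Both are evaluated in
+        # fused_olmo_hybrid_gdn_gating.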
+ self.b_proj = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.b_proj", + ) + self.a_proj = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.a_proj", + ) + + # Fused conv1d: single parameter instead of 3 + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": _make_fused_conv1d_weight_loader( + [self.key_dim, self.key_dim, self.value_dim], + self.tp_size, + self.tp_rank, + ) + }, + ) + + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), + ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + ) + ) + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + # use eps=1e-5 to match FLA's FusedRMSNormGated + self.o_norm = RMSNormGated( + self.head_v_dim, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=current_platform.current_device(), + dtype=config.torch_dtype if hasattr(config, "torch_dtype") else None, + ) + + self.o_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # FLA triton kernels need a PyTorch-backed allocator for scratch + # memory (required by triton >= 3.x autotuner). Set once at init. + set_triton_allocator(current_platform.current_device()) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + + num_k_heads = self.num_k_heads // self.tp_size + num_v_heads = self.num_v_heads // self.tp_size + + query = rearrange(query, "l (h d) -> 1 l h d", h=num_k_heads, d=self.head_k_dim) + key = rearrange(key, "l (h d) -> 1 l h d", h=num_k_heads, d=self.head_k_dim) + value = rearrange(value, "l (h d) -> 1 l h d", h=num_v_heads, d=self.head_v_dim) + + # GQA expansion if needed + if num_v_heads > num_k_heads: + expand_ratio = num_v_heads // num_k_heads + query = query.unsqueeze(3).expand(-1, -1, -1, expand_ratio, -1) + query = query.reshape(1, query.shape[1], num_v_heads, self.head_k_dim) + key = key.unsqueeze(3).expand(-1, -1, -1, expand_ratio, -1) + key = key.reshape(1, key.shape[1], num_v_heads, self.head_k_dim) + + return query.contiguous(), key.contiguous(), value.contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + # NOTE: We wrap the ENTIRE linear attention forward (projections + + # core recurrence + output norm + output projection) in a single + # custom op, rather than just wrapping the recurrent core like + # other GDN models (e.g. Qwen3Next) do. 
+ # + # Why: torch.compile with inductor generates fused kernels for + # matmuls and pointwise ops. These fused kernels can differ in + # floating-point accumulation order from eager-mode cuBLAS, + # introducing small numerical differences (~1e-7 per op). For + # standard transformer attention this is harmless because each + # position is computed independently. But for the GDN recurrent + # state, these tiny input differences compound at every timestep + # across the full sequence length, causing severe logprob + # divergence (e.g. ~15% top-1 agreement with eager baseline). + # + # By making the full forward opaque to inductor, the projections + # and output norm run with eager-mode kernels (cuBLAS, triton), + # preserving numerical consistency. The tradeoff is reduced + # compilation speedup (~1.5x vs ~3x), but logprob agreement + # improves from ~15% to ~83% top-1 vs eager. + # + # The remaining ~17% divergence comes from inductor compiling + # the MLP and transformer attention layers that are NOT wrapped + # in custom ops -- their small precision differences propagate + # as inputs to the GDN layers from outside. + torch.ops.vllm.olmo_hybrid_gdn_full_forward( + hidden_states, + output, + self.prefix, + ) + + def _full_forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + num_tokens = hidden_states.size(0) + + # ============================================================ + # Part 1: Input Projection (2 fused matmuls instead of 6) + # ============================================================ + projected_qkvg, _ = self.in_proj_qkvg(hidden_states) + conv_dim_sharded = (self.key_dim * 2 + self.value_dim) // self.tp_size + mixed_qkv = projected_qkvg[..., :conv_dim_sharded] + gate = projected_qkvg[..., conv_dim_sharded:] + + b, _ = self.b_proj(hidden_states) + a, _ = self.a_proj(hidden_states) + + # ============================================================ + # Part 2: Core Attention + # ============================================================ + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + self._forward_core( + mixed_qkv=mixed_qkv, + b=b, + a=a, + core_attn_out=core_attn_out, + ) + + # ============================================================ + # Part 3: Output Projection + # ============================================================ + gate = gate.view(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim) + core_attn_out_flat = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + gate_flat = gate.reshape(-1, gate.shape[-1]) + core_attn_out_normed = self.o_norm(core_attn_out_flat, gate_flat) + core_attn_out = core_attn_out_normed.view( + num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim + ) + + core_attn_out = rearrange(core_attn_out, "l h d -> l (h d)") + output[:num_tokens], _ = self.o_proj(core_attn_out) + + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """ + Core attention computation (called by custom op). 
+ """ + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + None, # no bias + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0][ + : attn_metadata.num_spec_decodes + ], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + if attn_metadata.num_prefills > 0: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec_T, + conv_weights, + None, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + None, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec + ) + + g, beta = fused_olmo_hybrid_gdn_gating( + self.A_log, a, b, self.dt_bias, self.allow_neg_eigval + ) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec 
= g + beta_non_spec = beta + + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_spec, last_recurrent_state = None, None + + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + use_qk_l2norm_in_kernel=True, + ) + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype + ) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + ) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + merged_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) + elif spec_sequence_masks is not None: + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + + +class OlmoHybridAttention(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + + hidden_size = self.config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = self.config.num_attention_heads + + assert hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = ( + self.config.num_key_value_heads or self.total_num_heads + ) + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = self.config.max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.tp_rank = get_tensor_model_parallel_rank() + 
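+        # q_norm/k_norm are defined over the full (unsharded) q/k widths, so
+        # with TP > 1 the activations are all-gathered, normalized, and then
+        # re-split across ranks in _apply_qk_norm (OLMo2-style attention, per
+        # the header note at the top of this file).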
+ self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, + eps=self.config.rms_norm_eps, + ) + self.q_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.attn", + ) + + rope_parameters = getattr(self.config, "rope_parameters", None) + self._use_rope = (rope_parameters is not None) and ( + rope_parameters["rope_theta"] is not None + ) + + if self._use_rope: + self.rotary_emb = get_rope( + self.head_dim, + max_position=self.max_position_embeddings, + rope_parameters=rope_parameters, + ) + else: + self.rotary_emb = None + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.o_proj", + ) + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm(q) + k = self.k_norm(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + if self._use_rope: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoHybridMLP(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + + self.act_fn = SiluAndMul() + + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoHybridDecoderLayer(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + layer_idx = extract_layer_index(prefix) + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + + if self.layer_type == "linear_attention": + self.linear_attn = OlmoHybridGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.linear_attn", + ) + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = RMSNorm( + 
config.hidden_size, + eps=config.rms_norm_eps, + ) + else: + self.self_attn = OlmoHybridAttention( + vllm_config=vllm_config, + prefix=f"{prefix}.self_attn", + ) + # Attention layers use these norm names + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.mlp = OlmoHybridMLP( + vllm_config=vllm_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if self.layer_type == "linear_attention": + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + attn_output = torch.empty_like(hidden_states) + self.linear_attn( + hidden_states=hidden_states, + output=attn_output, + ) + hidden_states = residual + attn_output + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + else: + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@support_torch_compile +class OlmoHybridModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: OlmoHybridDecoderLayer( + vllm_config=vllm_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + assert isinstance(hidden_states, torch.Tensor) + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + linear_attn_stacked_params_mapping = [ + ("in_proj_qkvg", "q_proj", 0), + ("in_proj_qkvg", 
"k_proj", 1), + ("in_proj_qkvg", "v_proj", 2), + ("in_proj_qkvg", "g_proj", 3), + ("conv1d", "q_conv1d", 0), + ("conv1d", "k_conv1d", 1), + ("conv1d", "v_conv1d", 2), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + + handled = False + + if "linear_attn" in name: + for ( + param_name, + weight_name, + shard_id, + ) in linear_attn_stacked_params_mapping: + if weight_name not in name: + continue + mapped_name = name.replace(weight_name, param_name) + if mapped_name.endswith(".bias") and ( + mapped_name not in params_dict + ): + continue + if mapped_name not in params_dict: + continue + param = params_dict[mapped_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + name = mapped_name + handled = True + break + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + handled = True + break + + if not handled: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class OlmoHybridForCausalLM( + nn.Module, HasInnerState, SupportsPP, SupportsLoRA, IsHybrid +): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj_qkvg": ["q_proj", "k_proj", "v_proj", "g_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + + self.model = OlmoHybridModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return 
MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) + + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), + ) + return loader.load_weights(weights) + + +def olmo_hybrid_gdn_full_forward( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + """Full linear attention forward wrapped as a custom op. + + Prevents inductor from compiling the projections around the GDN core, + which would introduce numerical divergence that compounds through + the recurrent state. + """ + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._full_forward( + hidden_states=hidden_states, + output=output, + ) + + +def olmo_hybrid_gdn_full_forward_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + """Fake implementation for torch.compile.""" + return + + +direct_register_custom_op( + op_name="olmo_hybrid_gdn_full_forward", + op_func=olmo_hybrid_gdn_full_forward, + mutates_args=["output"], + fake_impl=olmo_hybrid_gdn_full_forward_fake, +) + + +@triton.jit +def fused_olmo_hybrid_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + allow_neg_eigval: tl.constexpr, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + + # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + # beta = self.b_proj(hidden_states).sigmoid() + # if self.allow_neg_eigval: beta = beta * 2.0 + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + if allow_neg_eigval: + blk_beta_output = blk_beta_output * 2.0 + tl.store( + beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask + ) + + +def fused_olmo_hybrid_gdn_gating( + 
A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + allow_neg_eigval: bool = False, + beta: float = 1.0, + threshold: float = 20.0, +) -> tuple[torch.Tensor, torch.Tensor]: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) + fused_olmo_hybrid_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + allow_neg_eigval, + num_heads, + beta, + threshold, + 8, + num_warps=1, + ) + return g, beta_output diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1e5accaf3..274b18f35 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = { "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 00129d52e..3d379de8b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( speculators="SpeculatorsConfig", nemotron="NemotronConfig", olmo3="Olmo3Config", + olmo_hybrid="OlmoHybridConfig", ovis="OvisConfig", ultravox="UltravoxConfig", step3_vl="Step3VLConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8b5d08b8a..7902515e2 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = { "NemotronConfig": "vllm.transformers_utils.configs.nemotron", "NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h", "Olmo3Config": "vllm.transformers_utils.configs.olmo3", + "OlmoHybridConfig": "vllm.transformers_utils.configs.olmo_hybrid", "OvisConfig": "vllm.transformers_utils.configs.ovis", "PixelShuffleSiglip2VisionConfig": "vllm.transformers_utils.configs.isaac", "RadioConfig": "vllm.transformers_utils.configs.radio", @@ -102,6 +103,7 @@ __all__ = [ "NemotronConfig", "NemotronHConfig", "Olmo3Config", + "OlmoHybridConfig", "OvisConfig", "PixelShuffleSiglip2VisionConfig", "RadioConfig", diff --git a/vllm/transformers_utils/configs/olmo_hybrid.py b/vllm/transformers_utils/configs/olmo_hybrid.py new file mode 100644 index 000000000..1087124c7 --- /dev/null +++ b/vllm/transformers_utils/configs/olmo_hybrid.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from transformers.configuration_utils import PretrainedConfig, layer_type_validation + + +class OlmoHybridConfig(PretrainedConfig): + r""" + Configuration class for [`OlmoHybridModel`]. It is used to + instantiate an OLMo Hybrid model according to the specified + arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar + configuration to that of the + [allenai/Olmo-Hybrid-7B](https://huggingface.co/allenai/Olmo-Hybrid-7B) + model. 
+ + Configuration objects inherit from [`PreTrainedConfig`] and + can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 100352): + Vocabulary size of the OlmoHybrid model. Defines + the number of different tokens that can be + represented by the `inputs_ids` passed when + calling [`OlmoHybridModel`]. + hidden_size (`int`, *optional*, defaults to 3840): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, + defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, + defaults to 32): + Number of hidden layers in the Transformer + decoder. + num_attention_heads (`int`, *optional*, + defaults to 30): + Number of attention heads for each attention + layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that + should be used to implement Grouped Query + Attention. If + `num_key_value_heads=num_attention_heads`, + the model will use Multi Head Attention (MHA), + if `num_key_value_heads=1` the model will use + Multi Query Attention (MQA) otherwise GQA is + used. When converting a multi-head checkpoint + to a GQA checkpoint, each group key and value + head should be constructed by meanpooling all + the original heads within that group. For more + details, check out + [this paper](https://huggingface.co/papers/2305.13245). + If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, + defaults to `"silu"`): + The non-linear activation function (function + or string) in the decoder. + max_position_embeddings (`int`, *optional*, + defaults to 65536): + The maximum sequence length that this model + might ever be used with. + initializer_range (`float`, *optional*, + defaults to 0.02): + The standard deviation of the + truncated_normal_initializer for initializing + all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last + key/values attentions (not used by all models). + Only relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, + defaults to 100277): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, + defaults to 100257): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, + defaults to `False`): + Whether to tie weight embeddings. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration + parameters for the RoPE embeddings. Can be + `None` to disable RoPE. + attention_bias (`bool`, *optional*, + defaults to `False`): + Whether to use a bias in the query, key, value + and output projection layers during + self-attention. + attention_dropout (`float`, *optional*, + defaults to 0.0): + The dropout ratio for the attention + probabilities. + rms_norm_eps (`float`, *optional*, + defaults to 1e-06): + The epsilon used by the rms normalization + layers. + layer_types (`list`, *optional*): + Attention pattern for each layer. Can contain + `"full_attention"` or `"linear_attention"`. + Defaults to linear attention for most layers + with full attention for every 4th layer. + linear_num_key_heads (`int`, *optional*): + Number of key heads for the linear attention + layers. Defaults to `num_attention_heads`. 
+ linear_num_value_heads (`int`, *optional*): + Number of value heads for the linear attention + layers. Defaults to `num_attention_heads`. + linear_key_head_dim (`int`, *optional*): + Dimension of each key head in linear attention + layers. Defaults to + `0.75 * hidden_size / linear_num_key_heads`. + linear_value_head_dim (`int`, *optional*): + Dimension of each value head in linear + attention layers. Defaults to + `2 * linear_key_head_dim`. + linear_a_log_min (`float`, *optional*, + defaults to 0.0): + Minimum value for uniform initialization of + A_log in GatedDeltaNet layers. + linear_a_log_max (`float`, *optional*, + defaults to 16.0): + Maximum value for uniform initialization of + A_log in GatedDeltaNet layers. + linear_dt_min (`float`, *optional*, + defaults to 0.001): + Minimum value for dt initialization in + GatedDeltaNet layers. + linear_dt_max (`float`, *optional*, + defaults to 0.1): + Maximum value for dt initialization in + GatedDeltaNet layers. + linear_dt_init_floor (`float`, *optional*, + defaults to 0.0001): + Floor value for clamping dt during + initialization in GatedDeltaNet layers. + linear_conv_kernel_dim (`int`, *optional*, + defaults to 4): + Kernel size for the short convolution applied + to queries, keys, and values in linear + attention layers. + linear_allow_neg_eigval (`bool`, *optional*, + defaults to `True`): + Whether to allow negative eigenvalues in the + GatedDeltaNet recurrence. When `True`, the + beta parameter is scaled by 2.0 to allow + values in range [0, 2] instead of [0, 1]. + ```python + >>> from transformers import ( + ... OlmoHybridModel, + ... OlmoHybridConfig, + ... ) + + >>> configuration = OlmoHybridConfig() + >>> model = OlmoHybridModel(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "olmo_hybrid" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise_gather_output", + "layers.*.self_attn.k_proj": "colwise_gather_output", + "layers.*.self_attn.v_proj": "colwise_gather_output", + "layers.*.self_attn.o_proj": "rowwise_split_input", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: int | None = 100352, + hidden_size: int | None = 3840, + intermediate_size: int | None = 11008, + num_hidden_layers: int | None = 32, + num_attention_heads: int | None = 30, + num_key_value_heads: int | None = None, + hidden_act: str | None = "silu", + max_position_embeddings: int | None = 65536, + initializer_range: float | None = 0.02, + use_cache: bool | None = True, + pad_token_id: int | None = 100277, + bos_token_id: int | None = None, + eos_token_id: int | None = 100257, + tie_word_embeddings: bool | None = False, + rope_parameters=None, + attention_bias: bool | None = False, + attention_dropout: float | None = 0.0, + rms_norm_eps: float | None = 1e-06, + layer_types: list[str] | None = None, + linear_num_key_heads: int | None = None, + linear_num_value_heads: int | None = None, + linear_key_head_dim: int | None = None, + linear_value_head_dim: int | None = None, + linear_a_log_min: float = 0.0, + linear_a_log_max: float = 16.0, + linear_dt_min: float = 0.001, + linear_dt_max: float = 0.1, + linear_dt_init_floor: float = 1e-4, + linear_conv_kernel_dim: int = 
4, + linear_allow_neg_eigval: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + assert num_hidden_layers is not None + assert hidden_size is not None + assert num_attention_heads is not None + + if layer_types is None: + # Default: linear attention for most layers, full attention every 4th layer + layer_types = ["linear_attention"] * int(num_hidden_layers) + for i in range(int(num_hidden_layers)): + if i % 4 == 3: + layer_types[i] = "full_attention" + # Ensure at least one full attention layer for small num_hidden_layers + if "full_attention" not in layer_types: + layer_types[-1] = "full_attention" + + layer_type_validation(layer_types, num_hidden_layers) + if "linear_attention" not in layer_types: + raise ValueError( + "OLMoHybrid expects at least one 'linear_attention' layer." + ) + if all(t == "linear_attention" for t in layer_types): + raise ValueError("OLMoHybrid expects at least one attention layer.") + + self.layer_types = layer_types + + if linear_num_key_heads is None: + linear_num_key_heads = num_attention_heads + if linear_num_value_heads is None: + linear_num_value_heads = num_attention_heads + if linear_key_head_dim is None: + linear_key_head_dim = int(0.75 * hidden_size / linear_num_key_heads) + if linear_value_head_dim is None: + linear_value_head_dim = 2 * linear_key_head_dim + + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_a_log_min = linear_a_log_min + self.linear_a_log_max = linear_a_log_max + self.linear_dt_min = linear_dt_min + self.linear_dt_max = linear_dt_max + self.linear_dt_init_floor = linear_dt_init_floor + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_allow_neg_eigval = linear_allow_neg_eigval + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_parameters = rope_parameters + + self.tie_word_embeddings = tie_word_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id
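As a sanity check on the fused gating kernel introduced above, the same math can be written in plain PyTorch. The sketch below is illustrative only (the `gdn_gating_reference` name and the toy shapes are not part of the patch); it mirrors the kernel's default `beta=1.0` / `threshold=20.0` softplus and its `(1, tokens, num_heads)` float32 output layout:

```python
import torch
import torch.nn.functional as F


def gdn_gating_reference(
    A_log: torch.Tensor,    # (num_heads,)
    a: torch.Tensor,        # (tokens, num_heads)
    b: torch.Tensor,        # (tokens, num_heads)
    dt_bias: torch.Tensor,  # (num_heads,)
    allow_neg_eigval: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Unfused equivalent of fused_olmo_hybrid_gdn_gating (float32 math)."""
    # Per-head decay: g = -exp(A_log) * softplus(a + dt_bias)
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())
    # Per-head mixing coefficient: beta = sigmoid(b), optionally scaled to [0, 2]
    beta = b.float().sigmoid()
    if allow_neg_eigval:
        beta = beta * 2.0
    # The triton kernel writes (1, tokens, num_heads) float32 outputs
    return g.unsqueeze(0), beta.unsqueeze(0)


if __name__ == "__main__":
    torch.manual_seed(0)
    tokens, heads = 16, 8
    g_ref, beta_ref = gdn_gating_reference(
        A_log=torch.rand(heads),
        a=torch.randn(tokens, heads),
        b=torch.randn(tokens, heads),
        dt_bias=torch.ones(heads),
        allow_neg_eigval=True,
    )
    print(g_ref.shape, beta_ref.shape)  # torch.Size([1, 16, 8]) for both
```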