# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
from typing import Optional, Union

import torch
import torch.nn as nn
from transformers.activations import ACT2FN

import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                   SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import make_layers, maybe_prefix

logger = init_logger(__name__)


class SwiGLUActivation(nn.Module):

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return x1 * nn.functional.silu(x2)


class SambaYMLP(nn.Module):
    """Gated Linear Unit.

    Reference:
        Language Modeling with Gated Convolutional Networks.
        https://arxiv.org/pdf/1612.08083v3.pdf.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.fc1 = nn.Linear(config.hidden_size,
                             2 * config.intermediate_size,
                             bias=False)
        self.fc2 = nn.Linear(config.intermediate_size,
                             config.hidden_size,
                             bias=False)

        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        y = self.fc1(hidden_states)
        gate, y = y.chunk(2, dim=-1)
        y = y * self.activation_fn(gate)
        return self.fc2(y)


def get_virtual_engine():
    forward_context: ForwardContext = get_forward_context()
    return forward_context.virtual_engine


class SambaYAttention(nn.Module):

    def __init__(self,
                 config,
                 layer_idx: Optional[int] = None,
                 yoco_cross: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = ""):
        super().__init__()
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing "
                "a `layer_idx` is not recommended and will lead to errors "
                "during the forward call if caching is used. Please make "
                "sure to provide a `layer_idx` when creating this class.")

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.yoco_cross = yoco_cross

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_heads "
                             f"(got `hidden_size`: {self.hidden_size} and "
                             f"`num_heads`: {self.num_heads}).")

        op_size = self.num_heads * self.head_dim + 2 * (
            self.num_key_value_heads * self.head_dim)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim,
                                  self.hidden_size,
                                  bias=True)
        if yoco_cross:
            self.Wqkv = nn.Linear(self.hidden_size,
                                  self.num_heads * self.head_dim,
                                  bias=True)
        else:
            self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)

        # disable sliding window for the second half of the model
        sliding_window = config.interleaved_sliding_window[layer_idx]
        if layer_idx >= config.num_hidden_layers // 2:
            assert sliding_window is None, \
                "sliding_window must be None for the second half of the model"
        else:
            assert sliding_window is not None, \
                "sliding_window must be set for the first half of the model"

        assert self.num_heads % 2 == 0, 'num_heads should be even'
        assert self.num_key_value_heads % 2 == 0, \
            'num_key_value_heads should be even'

        self.lambda_init = self.lambda_init_fn(layer_idx)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim,
                        dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim,
                        dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim,
                        dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim,
                        dtype=torch.float32).normal_(mean=0, std=0.1))
        self.subln = nn.RMSNorm(2 * self.head_dim,
                                eps=1e-5,
                                elementwise_affine=True)

        params = {
            'differential_flash_attention_config': {
                'lambda_init': self.lambda_init,
                'lambda_q1': self.lambda_q1,
                'lambda_k1': self.lambda_k1,
                'lambda_q2': self.lambda_q2,
                'lambda_k2': self.lambda_k2,
                "subln": self.subln,
            }
        }

        if yoco_cross:
            kv_shared_layer_index = config.num_hidden_layers // 2 + 1
            kv_sharing_target_layer_name = \
                f"model.layers.{kv_shared_layer_index}.self_attn.attn"
        else:
            kv_sharing_target_layer_name = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.head_dim**-0.5,
            num_kv_heads=self.num_key_value_heads,
            cache_config=cache_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            **params)
        assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN, \
            "DIFFERENTIAL_FLASH_ATTN required"

    def lambda_init_fn(self, depth):
        return 0.8 - 0.6 * math.exp(-0.3 * depth)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        if not self.yoco_cross:  # need to generate kv-cache
            qkv = self.Wqkv(hidden_states)
            q, k, v = qkv.split(
                [
                    self.hidden_size,
                    self.num_key_value_heads * self.head_dim,
                    self.num_key_value_heads * self.head_dim,
                ],
                dim=-1)
            attn_output = self.attn(q, k, v)
        else:  # re-use the kv cache, full attention
            q = self.Wqkv(hidden_states)
            attn_output = self.attn(q, None, None)
        attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
        return self.out_proj(attn_output)


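# Phi4Mamba (defined below) is a Mamba selective-SSM block with two
# YOCO-specific modes, summarized here from its forward pass:
#   * yoco_kv:    the raw SSM scan output is saved and returned as
#                 `yoco_key_values` for reuse by later cross-decoder layers;
#                 the gate is not fused into the scan kernel (z=None) so the
#                 ungated output can be shared, and SwiGLU gating is applied
#                 afterwards.
#   * yoco_cross: the convolution and scan are skipped entirely; the block
#                 only applies in_proj, a SwiGLU gate against the shared
#                 `yoco_key_values`, and out_proj.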
class Phi4Mamba(nn.Module):

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",  # difference
        dt_scale=1.0,  # difference
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
        yoco_cross=False,
        yoco_kv=False,
    ):
        factory_kwargs = {"params_dtype": dtype}  # difference
        super().__init__()
        self.yoco_cross = yoco_cross
        self.yoco_kv = yoco_kv
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(
            self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.swiGluActivation = SwiGLUActivation()

        if self.yoco_cross:
            self.in_proj = MergedColumnParallelLinear(self.d_model,
                                                      [self.d_inner],
                                                      bias=bias,
                                                      **factory_kwargs)
            self.out_proj = RowParallelLinear(self.d_inner,
                                              self.d_model,
                                              bias=bias,
                                              **factory_kwargs)
            return

        self.conv1d = ColumnParallelLinear(
            input_size=d_conv,
            output_size=self.d_inner,
            bias=conv_bias,
            params_dtype=dtype,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow overriding it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            self.d_model,
            [self.d_inner] * 2,
            bias=bias,
            params_dtype=dtype,
        )

        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.d_inner,
            self.dt_rank + self.d_state * 2,
            bias=False,
            params_dtype=dtype,
        )

        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(
            self.dt_rank,
            self.d_inner,
            bias=True,
            skip_bias_add=True,
            params_dtype=dtype,
        )

        # # D "skip" parameter
        # self.D = nn.Parameter(torch.ones(self.d_inner))

        # Keep in fp32
        self.A = nn.Parameter(
            torch.empty(
                self.d_inner,
                self.d_state,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))

        self.out_proj = RowParallelLinear(
            self.d_inner,
            self.d_model,
            bias=bias,
            input_is_parallel=True,
            params_dtype=dtype,
        )
        self.activation = "silu"

    def forward(self,
                hidden_states: torch.Tensor,
                attn_metadata: AttentionMetadata,
                mamba_cache_params: MambaCacheParams,
                yoco_key_values=None) -> torch.Tensor:

        if self.yoco_cross:
            out = self.in_proj(hidden_states)[0]
            out = self.swiGluActivation(yoco_key_values, out)
            out = self.out_proj(out)
            return out[0], yoco_key_values

        # 1. Gated MLP's linear projection
        # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
        projected_states = self.in_proj(
            hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
        hidden_states, gate = projected_states.chunk(2, dim=-2)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )

        # Note that Jamba normalizes B, C, and time_step here but Mamba
        # doesn't.
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)

        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                # z,
                None if self.yoco_kv else gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            scan_outputs = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                # z
                # gate.transpose(0, 1),
                None if self.yoco_kv else gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
        if self.yoco_kv:
            # gate = gate.transpose(-1,-2).contiguous()
            yoco_key_values = scan_outputs.transpose(-2, -1)
            scan_outputs = self.swiGluActivation(scan_outputs, gate)

        contextualized_states = self.out_proj(
            scan_outputs.transpose(-2, -1))[0]
        return contextualized_states, yoco_key_values


class SambaYDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        layer_idx,
        cache_config,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.mlp = SambaYMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

        self.yoco_mb = False
        self.yoco_cross = False
        if layer_idx >= config.num_hidden_layers // 2:
            self.yoco_mb = True
            self.yoco_cross = (layer_idx
                               >= (config.num_hidden_layers // 2 + 2))
        self.use_mamba = config.mb_per_layer > 0 and \
            layer_idx % config.mb_per_layer == 0
        if self.use_mamba:
            factory_kwargs = {"dtype": None}
            self.attn = Phi4Mamba(config.hidden_size,
                                  layer_idx=layer_idx,
                                  yoco_cross=self.yoco_cross,
                                  yoco_kv=self.yoco_mb,
                                  **factory_kwargs)
        else:
            self.attn = SambaYAttention(config,
                                        layer_idx=layer_idx,
                                        yoco_cross=self.yoco_cross,
                                        cache_config=cache_config,
                                        prefix=f"{prefix}.self_attn")
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        ssm_output: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if self.use_mamba:
            assert mamba_cache_params is not None
        else:
            assert mamba_cache_params is None

        residual = hidden_states
        hidden_states = self.input_layernorm(
            hidden_states.to(dtype=self.input_layernorm.weight.dtype))

        if self.use_mamba:
            attn_outputs, ssm_output = self.attn(hidden_states,
                                                 attn_metadata,
                                                 mamba_cache_params,
                                                 yoco_key_values=ssm_output)
            residual = residual.to(torch.float32)
        else:
            attn_outputs = self.attn(hidden_states)
        hidden_states = residual + attn_outputs

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(
            hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, ssm_output


class SambaYModel(nn.Module):

    def __init__(self,
                 config,
                 cache_config=None,
                 quant_config=None,
                 lora_config=None,
                 prefix: str = "") -> None:
        super().__init__()

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

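        # Layer layout (derived from the flags set in SambaYDecoderLayer and
        # the KV-sharing target in SambaYAttention): layers [0, n // 2) form
        # the self-decoder, where sliding-window attention is interleaved
        # with Mamba every `mb_per_layer` layers; layers [n // 2, n) form the
        # cross-decoder, where layer n // 2 + 1 produces the single
        # full-attention KV cache that attention layers from n // 2 + 2
        # onwards reuse (YOCO).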

        # Pipeline parallel is not supported since the second half of
        # the layers share the kv cache.
        if get_pp_group().world_size != 1:
            raise ValueError("Pipeline Parallel not supported")

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: SambaYDecoderLayer(config,
                                              int(prefix.split('.')[-1]),
                                              cache_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        mamba_state_idx = 0
        ssm_output = None
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i == self.config.num_hidden_layers // 2 + 2:
                # profile run
                kv_cache_idx = self.config.num_hidden_layers // 2 + 1
                cache_layer = self.layers[kv_cache_idx]
                kv_cache = cache_layer.attn.attn.kv_cache
                if kv_cache[0].numel() == 0:
                    break

                # Starting from this layer, we do not need to calculate
                # the kv cache since we reuse the kv cache from the last
                # layer. In the prefill phase, we can prune the hidden
                # states down to the last token of each sequence to save
                # computation cost.
                if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
                    selected_token_indices = torch.cumsum(
                        attn_metadata.seq_lens_tensor, dim=0) - 1
                    hidden_states = hidden_states.index_select(
                        0, selected_token_indices)
                    ssm_output = ssm_output.index_select(
                        0, selected_token_indices)

            if layer.use_mamba:
                if i < self.config.num_hidden_layers // 2 or \
                        not layer.yoco_cross:
                    mamba_cache = mamba_cache_params.at_layer_idx(
                        mamba_state_idx)
                    mamba_state_idx += 1
                else:
                    mamba_cache = mamba_cache_params.at_layer_idx(
                        mamba_state_idx - 1)

                hidden_states, ssm_output = layer(hidden_states,
                                                  positions,
                                                  attn_metadata,
                                                  mamba_cache,
                                                  ssm_output=ssm_output)
            else:
                hidden_states, ssm_output = layer(
                    hidden_states,
                    positions,
                    attn_metadata,
                    None,  # mamba_cache_params
                    ssm_output=ssm_output)

        hidden_states = self.final_layernorm(
            hidden_states.to(dtype=self.final_layernorm.weight.dtype))
        return hidden_states


class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid,
                           SupportsV0Only):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        quant_config = vllm_config.quant_config
        scheduler_config = vllm_config.scheduler_config
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        # Prefix caching and chunked prefill are not supported for this
        # model.
        assert not cache_config.enable_prefix_caching, \
            "Phi4Flash currently does not support prefix caching"

        assert not scheduler_config.chunked_prefill_enabled, \
            "Phi4Flash currently does not support chunked prefill"

        super().__init__()
        self.config = config
        self.model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config
        self.model = SambaYModel(config,
                                 cache_config=cache_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size),
            quant_config=quant_config,
        )
        self.embedding_bias = None
        # Used to track and store the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                logits_as_input=False)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if self.mamba_cache is None:
            num_mamba_layers = self.config.num_hidden_layers \
                // 2 // self.config.mb_per_layer + 1
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        attn_metadata = get_forward_context().attn_metadata
        # input_ids and hidden_states are not a one-to-one mapping in the
        # prefill stage due to the YOCO optimization.
        hidden_states = self.model(input_ids, positions, attn_metadata,
                                   mamba_cache_params, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

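    # Worked example (hypothetical config values, for illustration only):
    # with hidden_size = 2560, mamba_expand = 2, mamba_d_conv = 4,
    # mamba_d_state = 16, and a tensor-parallel world size of 1, the method
    # below returns conv_state_shape = (5120, 3) and
    # temporal_state_shape = (5120, 16) for each Mamba layer.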
    def _get_mamba_cache_shape(
        self
    ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        mamba_expand = self.config.mamba_expand  # 2
        mamba_d_conv = self.config.mamba_d_conv  # 4
        mamba_d_state = self.config.mamba_d_state  # 16
        conv_state_shape = (
            mamba_expand * hidden_size // world_size,
            mamba_d_conv - 1,
        )
        temporal_state_shape = (
            mamba_expand * hidden_size // world_size,
            mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # If the shapes match, the hidden states have already been pruned
        # manually (in SambaYModel.forward), so the logits processor does
        # not need to prune them again.
        prune_hidden_states = hidden_states.size(
            0) != sampling_metadata.selected_token_indices.size(0)
        processed_logits = self.logits_processor(
            self.lm_head,
            hidden_states,
            sampling_metadata,
            self.embedding_bias,
            prune_hidden_states=prune_hidden_states)
        return processed_logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ):
        weights = {name: weight for name, weight in weights}
        adjusted_weights = {}
        for name, weight in weights.items():
            if "A_log" in name:
                name = name.replace("A_log", "A")
                weight = -torch.exp(weight.float())
            if "inner_cross_attn." in name:
                name = name.replace("inner_cross_attn.", "")
            adjusted_weights[name] = weight
        adjusted_weights["lm_head.weight"] = weights[
            "model.embed_tokens.weight"]
        loaded_params: set[str] = set()
        for name, param in self.named_parameters():
            weight = adjusted_weights.get(name)
            if weight is not None and weight.shape != param.shape:
                logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
                               param.shape)
            loaded_params.add(name)
        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
                                                             strict=False)
        assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
        return loaded_params