[V1] Remove V0 code paths for Hybrid models (#25400)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-09-23 17:26:13 +02:00
committed by GitHub
parent 2c58742dff
commit a903669e10
31 changed files with 352 additions and 2296 deletions

View File

@@ -9,7 +9,6 @@ import torch
from torch import nn
from transformers import JambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params)
self.mamba(hidden_states, output)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)
@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
kv_cache_index = 0
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1
if isinstance(layer,
JambaMambaDecoderLayer) and mamba_cache_params:
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
mamba_cache_index += 1
hidden_states, residual = layer(positions=positions,
hidden_states=hidden_states,
residual=residual)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_size=hf_config.mamba_expand * hidden_size,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=envs.VLLM_USE_V1,
)
def compute_logits(