[V1] Remove V0 code paths for Hybrid models (#25400)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Date: 2025-09-23 17:26:13 +02:00
Committed by: GitHub
Parent: 2c58742dff
Commit: a903669e10
31 changed files with 352 additions and 2296 deletions
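The change is mechanical but wide: it deletes the runtime envs.VLLM_USE_V1 branches with which hybrid/Mamba models chose between the V0 path (the model lazily builds a MambaCacheManager and threads cache and metadata objects through every forward call) and the V1 path (the engine owns the state). A condensed sketch of the pattern being removed, distilled from the hunks below rather than quoted from any single file:

# Before: V0 and V1 gated at runtime.
if not envs.VLLM_USE_V1:
    # V0: the model allocates its own cache manager and passes
    # per-step tensors down through forward().
    mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
    # V1: nothing to thread through; state lives with the engine.
    mamba_cache_params = None

# After: V1 only. The branch, the cache manager, and the extra
# forward() arguments disappear entirely.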


@@ -8,16 +8,11 @@ import torch
from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        if residual is None:
@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
            hidden_states, residual = self.norm(hidden_states, residual)
        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
        self.mixer(hidden_states, output)
        return output, residual
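Taken together, the two hunks above reduce Mamba2DecoderLayer.forward to roughly the following. The body of the "if residual is None" branch is not part of the diff, so the usual vLLM residual handling is assumed here; treat this as a sketch rather than the exact file contents:

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)  # assumed branch body
        else:
            hidden_states, residual = self.norm(hidden_states, residual)
        output = torch.empty_like(hidden_states)
        # The mixer no longer takes mamba_cache_params / mamba2_metadata;
        # under V1 it picks up what it needs from the forward context.
        self.mixer(hidden_states, output)
        return output, residual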
@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 get mamba2_metadata from forward_context
            mamba2_metadata = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer) if mamba_cache_params else None,
                mamba2_metadata=mamba2_metadata)
            hidden_states, residual = layer(positions=positions,
                                            hidden_states=hidden_states,
                                            residual=residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
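The removed comment "v1 get mamba2_metadata from forward_context" is the heart of the change: per-step metadata and per-layer state are no longer threaded through every forward() signature; the layers look them up from an ambient forward context instead. The self-contained toy below shows that pattern in isolation. It is only an illustration of the idea, with made-up names, not vLLM's actual forward-context API:

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

@dataclass
class StepMetadata:
    num_prefill_tokens: int
    num_decode_tokens: int

_CURRENT: Optional[StepMetadata] = None

@contextmanager
def step_context(meta: StepMetadata):
    # Install metadata for the duration of one model step.
    global _CURRENT
    prev, _CURRENT = _CURRENT, meta
    try:
        yield
    finally:
        _CURRENT = prev

def get_step_metadata() -> StepMetadata:
    assert _CURRENT is not None, "forward() called outside a step context"
    return _CURRENT

class ToyMixer:
    # No metadata argument: the layer pulls it from the ambient context,
    # which is how the V1 path lets forward() signatures shrink.
    def forward(self, hidden_states):
        meta = get_step_metadata()
        scale = 1.0 / max(1, meta.num_decode_tokens)
        return [h * scale for h in hidden_states]

with step_context(StepMetadata(num_prefill_tokens=4, num_decode_tokens=2)):
    print(ToyMixer().forward([1.0, 2.0, 3.0]))  # [0.5, 1.0, 1.5]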
@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.
        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)
        Returns:
            Tuple containing:
@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
            head_dim=hf_config.head_dim,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=use_v1,
        )
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
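With use_v1 gone, the classmethod computes a single set of shapes matching its annotated return type: a 2-tuple for the convolutional cache and a 3-tuple for the SSM state. As a rough sketch of how such shapes are typically derived for Mamba2 (an illustration only, with a hypothetical helper name; the real MambaStateShapeCalculator may pad or partition differently under tensor parallelism):

def mamba2_state_shapes(intermediate_size: int, n_groups: int, state_size: int,
                        conv_kernel: int, num_heads: int, head_dim: int,
                        tp_size: int = 1):
    # Conv cache keeps the last (conv_kernel - 1) inputs of the depthwise
    # conv over the concatenated [x, B, C] channels.
    conv_dim = intermediate_size + 2 * n_groups * state_size
    conv_state_shape = (conv_dim // tp_size, conv_kernel - 1)
    # SSM ("temporal") state: one (head_dim x state_size) matrix per head.
    temporal_state_shape = (num_heads // tp_size, head_dim, state_size)
    return conv_state_shape, temporal_state_shape

# Illustrative numbers only, not taken from a specific checkpoint:
print(mamba2_state_shapes(intermediate_size=4096, n_groups=1, state_size=128,
                          conv_kernel=4, num_heads=64, head_dim=64))
# -> ((4352, 3), (64, 64, 128))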
@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = (
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba))
                mamba_state_shape = \
                    self.get_mamba_state_shape_from_config(
                        self.vllm_config, use_v1=False)
                mamba_state_dtype = \
                    self.get_mamba_state_dtype_from_config(
                        self.vllm_config)
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     num_mamba_layers,
                                                     *mamba_state_shape,
                                                     *mamba_state_dtype)
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        else:
            # NOTE: mamba_cache_params is not needed for v1
            mamba_cache_params = None
        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
        hidden_states = self.backbone(input_ids, positions,
                                      intermediate_tensors, inputs_embeds)
        return hidden_states
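With the V0 branch gone, the top-level forward is a straight pass-through to the backbone. Assembled from the lines above; the "def forward(self, input_ids, positions, ..." header sits just above the hunk and is reconstructed here, so read this as a sketch:

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        hidden_states = self.backbone(input_ids, positions,
                                      intermediate_tensors, inputs_embeds)
        return hidden_states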