[V1] Remove V0 code paths for Hybrid models (#25400)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-09-23 17:26:13 +02:00
committed by GitHub
parent 2c58742dff
commit a903669e10
31 changed files with 352 additions and 2296 deletions

View File

@@ -14,7 +14,6 @@ import torch.distributed
from torch import nn
from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
is_warmup: bool = False,
@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_states=layernorm_output,
output=self_attention_output,
positions=positions,
kv_caches=kv_caches,
)
residual = residual * self.layernorm_attention_alpha
@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype
del _dummy
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
norm_kwargs = {}
if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps
@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None
if get_pp_group().is_first_rank:
if inputs_embeds is None:
@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
minimax_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
minimax_cache_index += 1
hidden_states, residual = layer(
hidden_states=hidden_states,
positions=positions,
kv_caches=_caches,
attn_metadata=attn_metadata,
residual=residual,
)
@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing: