[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -14,7 +14,6 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers import MiniMaxConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
|
||||
|
||||
@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Union[list[dict], Optional[torch.Tensor]],
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
is_warmup: bool = False,
|
||||
@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
hidden_states=layernorm_output,
|
||||
output=self_attention_output,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
residual = residual * self.layernorm_attention_alpha
|
||||
@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
self._dtype = _dummy.dtype
|
||||
del _dummy
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
self.minimax_cache = MinimaxCacheManager(
|
||||
dtype=torch.float32, cache_shape=self.cache_shape)
|
||||
|
||||
norm_kwargs = {}
|
||||
if hasattr(config, "rms_norm_eps"):
|
||||
norm_kwargs["eps"] = config.rms_norm_eps
|
||||
@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
return None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if "request_ids_to_seq_ids" not in kwargs:
|
||||
kwargs["request_ids_to_seq_ids"] = {}
|
||||
if "finished_requests_ids" not in kwargs:
|
||||
kwargs["finished_requests_ids"] = []
|
||||
(
|
||||
minimax_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.minimax_cache.current_run_tensors(**kwargs)
|
||||
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
||||
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
||||
**kwargs)
|
||||
|
||||
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
||||
state_indices_tensor)
|
||||
else:
|
||||
minimax_cache_params = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
minimax_cache_index = 0
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
_caches = None
|
||||
if not envs.VLLM_USE_V1 and isinstance(
|
||||
layer.self_attn, MiniMaxText01LinearAttention):
|
||||
current_state_layer = minimax_cache_index
|
||||
_caches = minimax_cache_params.at_layer_idx(
|
||||
current_state_layer)
|
||||
minimax_cache_index += 1
|
||||
hidden_states, residual = layer(
|
||||
hidden_states=hidden_states,
|
||||
positions=positions,
|
||||
kv_caches=_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
)
|
||||
@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, ...], ...]:
|
||||
"""Calculate shape for MiniMaxText01LinearAttention cache.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
|
||||
Reference in New Issue
Block a user