[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Thomas Parnell
2025-08-15 14:57:06 +02:00
committed by GitHub
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

View File

@@ -16,7 +16,8 @@ from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
@@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def mamba_type(self) -> str:
return "linear_attention"
def get_state_dtype(self) -> tuple[torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
@@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
max_position: int,
block_size: int,
num_hidden_layer: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
@@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
@@ -657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
def __init__(
self,
config: MiniMaxConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
expert_num: int = 1,
@@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
max_position=max_position_embeddings,
block_size=config.block if hasattr(config, "block") else 256,
num_hidden_layer=config.num_hidden_layers,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
layer_idx=self._ilayer,
linear_layer_idx=linear_layer_id,
@@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module):
def __init__(
self,
config: MiniMaxConfig,
model_config: Optional[ModelConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
@@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module):
decoder_kwargs = {
"quant_config": quant_config,
"layer_id": layer_idx,
"model_config": model_config,
"cache_config": cache_config
}
@@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
self.config,
quant_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
scheduler_config=vllm_config.scheduler_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
@@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
load_basic_weight(name, loaded_weight, self)
return loaded_params
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,