[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Thomas Parnell
2025-08-15 14:57:06 +02:00
committed by GitHub
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

View File

@@ -8,7 +8,7 @@ from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# For TP, the sharding plan is as follows:
@@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def forward_native(
@@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.mamba2_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=self.intermediate_size,