[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Author:    Thomas Parnell
Committed: 2025-08-15 14:57:06 +02:00 (by GitHub)
Parent:    22341b996e
Commit:    75531a6c13

23 changed files with 467 additions and 87 deletions
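The feature in one picture: the engine can now keep the recurrent (SSM) state of hybrid models in float32 while activations and the conv state stay in the model dtype. A hedged usage sketch, assuming the `CacheConfig` fields referenced throughout this diff (`mamba_cache_dtype`, `mamba_ssm_cache_dtype`) are exposed as engine arguments of the same names; the model name is only an example:

# Sketch, not a verified recipe: serve a hybrid Mamba2 model with its
# selective-scan state cached in fp32 for better numerical stability.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B",   # example Mamba2 hybrid; assumption
    mamba_ssm_cache_dtype="float32",    # new: fp32 temporal state only
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)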

vllm/model_executor/layers/mamba/mamba_mixer.py

@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -56,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp):
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
                  is_lora_enabled: bool = False,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
                  prefix: str = ""):
         super().__init__()
         self.time_step_rank = time_step_rank
@@ -153,6 +155,8 @@ class MambaMixer(MambaBase, CustomOp):
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix

     def _ssm_transform(
@@ -369,6 +373,15 @@ class MambaMixer(MambaBase, CustomOp):
         return out

+    def get_state_dtype(self) -> tuple[torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba1_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
             tp_world_size=get_tensor_model_parallel_world_size(),
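To see what the new `get_state_dtype()` hook returns, here is a small sketch against the calculator introduced in this commit (import path taken from the diff): "auto" falls through to the mamba2 resolution, while an explicit float32 request is rejected for Mamba1 until the kernels support it.

import torch

from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator)

# With everything on "auto", both states simply follow the model dtype.
print(MambaStateDtypeCalculator.mamba1_state_dtype(
    torch.bfloat16, "auto", "auto"))  # (torch.bfloat16, torch.bfloat16)

# Mamba1 kernels cannot consume an fp32 state yet, so the request fails fast.
try:
    MambaStateDtypeCalculator.mamba1_state_dtype(
        torch.bfloat16, "auto", "float32")
except ValueError as err:
    print(err)  # fp32 state for mamba1 is not yet supported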

vllm/model_executor/layers/mamba/mamba_mixer2.py

@@ -8,7 +8,7 @@ from torch import nn
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
     **selective** state spaces)
     """

-    def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 ssm_state_size: int,
+                 conv_kernel_size: int,
+                 intermediate_size: int,
+                 use_conv_bias: bool,
+                 use_bias: bool,
+                 n_groups: int = 1,
+                 num_heads: int = 128,
+                 head_dim: int = 64,
+                 rms_norm_eps: float = 1e-5,
+                 activation: str = "silu",
+                 use_rms_norm: bool = True,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()

         # For TP, the sharding plan is as follows:
@@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix

     def forward_native(
@@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                 self.head_dim),
-            )
+                state_dtype=ssm_state.dtype)

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)

+    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=self.intermediate_size,
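The `state_dtype=ssm_state.dtype` argument above is the crux of the prefill path: the chunk-scan kernel now materializes final states directly at the cache dtype instead of at `C.dtype` (the model dtype), avoiding a lossy intermediate cast when the cache is fp32. A pure-PyTorch illustration of the effect (not the Triton kernel itself):

import torch

# Kernel inputs live in the model dtype; the new option puts the cache in fp32.
value = torch.randn(4, 8, dtype=torch.float32) * 1.0001
model_dtype, cache_dtype = torch.bfloat16, torch.float32

# Old behavior: final states produced at the model dtype, then copied into
# the cache, so an fp32 cache still inherits bf16 rounding.
old_path = value.to(model_dtype).to(cache_dtype)

# New behavior: out_dtype pinned to the cache dtype, no intermediate cast.
new_path = value.to(cache_dtype)

print((old_path - new_path).abs().max())  # > 0: the bf16 round-trip loses bits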

vllm/model_executor/layers/mamba/mamba_utils.py

@@ -1,6 +1,58 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union
+
+import torch
+
+from vllm.config import MambaDType, ModelDType
 from vllm.distributed import divide
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
+
+
+class MambaStateDtypeCalculator:
+
+    @classmethod
+    def linear_attention_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        # TODO (tdoublep) requires testing
+        if mamba_cache_dtype == "float32":
+            raise ValueError("fp32 state for minimax is not yet supported")
+        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
+        return (state_dtype, )
+
+    @classmethod
+    def mamba1_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        # TODO (tdoublep) requires kernel changes
+        if (mamba_cache_dtype == "float32"
+                or mamba_ssm_cache_dtype == "float32"):
+            raise ValueError("fp32 state for mamba1 is not yet supported")
+        else:
+            return MambaStateDtypeCalculator.mamba2_state_dtype(
+                model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
+
+    @classmethod
+    def mamba2_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
+                                                    model_dtype)
+        if mamba_ssm_cache_dtype == "auto":
+            temporal_state_dtype = conv_state_dtype
+        else:
+            temporal_state_dtype = (
+                STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
+        return (conv_state_dtype, temporal_state_dtype)
+
+
 class MambaStateShapeCalculator:
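Putting the new calculator through the supported combinations (a sketch; it assumes `get_kv_cache_torch_dtype` maps "auto" to the model dtype and "float32" to `torch.float32`, which is what its use in this diff implies): the conv state tracks `mamba_cache_dtype`, and the temporal state can be overridden independently via `mamba_ssm_cache_dtype`.

import torch

from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator)

for cache_dt, ssm_dt in [
    ("auto", "auto"),      # both states follow the model dtype
    ("auto", "float32"),   # fp32 temporal state only: the headline feature
    ("float32", "auto"),   # fp32 conv state; temporal state follows it
]:
    conv, ssm = MambaStateDtypeCalculator.mamba2_state_dtype(
        torch.bfloat16, cache_dt, ssm_dt)
    print(f"cache={cache_dt:8} ssm={ssm_dt:8} -> conv={conv}, temporal={ssm}")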

vllm/model_executor/layers/mamba/ops/ssd_combined.py

@@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    cu_seqlens=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf")),
+                                   state_dtype=None,
                                    out=None):
     assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
     batch, seqlen, nheads, headdim = x.shape
@@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x,
         if initial_states is not None else None,
         seq_idx=seq_idx,
         chunk_size=chunk_size,
-        out_dtype=C.dtype,
+        out_dtype=state_dtype if state_dtype is not None else C.dtype,
         is_cont_batched=cu_seqlens is not None)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
@@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x,
                               dt_limit=(0.0, float("inf")),
                               out=None,
                               return_final_states=False,
-                              return_varlen_states=False):
+                              return_varlen_states=False,
+                              state_dtype=None):
     """
     Argument:
         x: (batch, seqlen, nheads, headdim)
@@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x,
         cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
         dt_softplus: Whether to apply softplus to dt
         out: Preallocated output tensor
+        state_dtype: The data type of the ssm state
     """

     if not return_varlen_states:
@@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
         dt_limit=dt_limit,
-        out=out)
+        out=out,
+        state_dtype=state_dtype)
     if not return_varlen_states:
         if not return_final_states:
             return
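The kwarg threads straight through: `mamba_chunk_scan_combined` forwards `state_dtype` to `_mamba_chunk_scan_combined_fwd`, which pins the kernel's `out_dtype` when it is set and keeps the old `C.dtype` default when it is `None`. The fallback in isolation (hypothetical helper name, for illustration only):

import torch

def resolve_out_dtype(state_dtype, C_dtype):
    # Mirrors the expression added above: the requested cache dtype wins when
    # given; otherwise final states stay at the dtype of C (the model dtype).
    return state_dtype if state_dtype is not None else C_dtype

assert resolve_out_dtype(None, torch.bfloat16) is torch.bfloat16          # old default
assert resolve_out_dtype(torch.float32, torch.bfloat16) is torch.float32  # fp32 cache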