[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Thomas Parnell
2025-08-15 14:57:06 +02:00
committed by GitHub
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

View File

@@ -26,7 +26,7 @@ from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -110,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -149,6 +150,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -167,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module):
head_dim=config.mamba_head_dim,
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
)
@@ -198,6 +202,7 @@ class NemotronHAttention(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -270,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -279,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
self.mixer = NemotronHAttention(
config,
layer_idx,
model_config,
cache_config,
quant_config,
prefix=f"{prefix}.mixer",
@@ -317,6 +324,7 @@ class NemotronHModel(nn.Module):
super().__init__()
config: NemotronHConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@@ -340,6 +348,7 @@ class NemotronHModel(nn.Module):
return layer_class(
config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
@@ -478,6 +487,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(
    cls,
    vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
    """Return the pair of dtypes to use for this model's Mamba state.

    The decision is delegated to
    ``MambaStateDtypeCalculator.mamba2_state_dtype``; this method only
    gathers the three inputs it needs from ``vllm_config``.

    Args:
        vllm_config: Engine configuration. The model dtype is read from
            ``model_config.dtype`` and the two cache dtype settings from
            ``cache_config.mamba_cache_dtype`` and
            ``cache_config.mamba_ssm_cache_dtype``.

    Returns:
        A 2-tuple of ``torch.dtype`` values, exactly as produced by the
        calculator.
        NOTE(review): presumably ordered (conv state, SSM state) --
        confirm against MambaStateDtypeCalculator.
    """
    model_cfg = vllm_config.model_config
    cache_cfg = vllm_config.cache_config
    return MambaStateDtypeCalculator.mamba2_state_dtype(
        model_cfg.dtype,
        cache_cfg.mamba_cache_dtype,
        cache_cfg.mamba_ssm_cache_dtype,
    )
@classmethod
def get_mamba_state_shape_from_config(
cls,
@@ -569,10 +590,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)