[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -9,13 +9,13 @@ from torch import nn
|
||||
from transformers import MambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateShapeCalculator)
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -40,6 +40,7 @@ class MambaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: MambaConfig,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
is_lora_enabled: Optional[bool] = False,
|
||||
@@ -61,6 +62,8 @@ class MambaDecoderLayer(nn.Module):
|
||||
rms_norm_eps=mixer_rms_eps,
|
||||
activation=config.hidden_act,
|
||||
is_lora_enabled=self.is_lora_enabled,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.mixer")
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@@ -88,6 +91,7 @@ class MambaModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
@@ -108,6 +112,7 @@ class MambaModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MambaDecoderLayer(config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
is_lora_enabled=is_lora_enabled,
|
||||
@@ -243,9 +248,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
self.vllm_config.parallel_config, LayerBlockType.mamba)
|
||||
state_shape = self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config)
|
||||
state_dtype = self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
self.lm_head.weight.dtype,
|
||||
num_layers, *state_shape)
|
||||
num_layers, *state_shape,
|
||||
*state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
@@ -254,6 +261,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
return MambaStateDtypeCalculator.mamba1_state_dtype(
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype,
|
||||
vllm_config.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user