[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -18,7 +18,7 @@ from transformers import Zamba2Config
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateShapeCalculator)
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -478,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Zamba2Config,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
"""Initialize the Mamba decoder layer.
|
||||
@@ -502,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
config.n_mamba_heads,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
activation="silu",
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mixer")
|
||||
|
||||
@@ -578,6 +582,8 @@ class Zamba2HybridLayer(nn.Module):
|
||||
shared_transformer: Zamba2AttentionDecoderLayer,
|
||||
config: Zamba2Config,
|
||||
block_idx: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@@ -596,6 +602,8 @@ class Zamba2HybridLayer(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.mamba_decoder = Zamba2MambaDecoderLayer(config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
@@ -669,6 +677,7 @@ class Zamba2Model(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
@@ -718,11 +727,15 @@ class Zamba2Model(nn.Module):
|
||||
Zamba2HybridLayer(block,
|
||||
config,
|
||||
block_idx,
|
||||
quant_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix))
|
||||
else:
|
||||
layers.append(
|
||||
Zamba2MambaDecoderLayer(config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix))
|
||||
self.layers = nn.ModuleList(layers)
|
||||
@@ -848,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
"1.weight": "B.weight",
|
||||
})
|
||||
|
||||
@classmethod
def get_mamba_state_dtype_from_config(
    cls,
    vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
    """Resolve the dtypes used for this model's mamba state.

    Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype``, passing
    the model dtype together with the two mamba cache dtype settings read
    from the cache config.

    Args:
        vllm_config: Configuration object providing ``model_config.dtype``,
            ``cache_config.mamba_cache_dtype`` and
            ``cache_config.mamba_ssm_cache_dtype``.

    Returns:
        A pair of torch dtypes, exactly as returned by the calculator.
        NOTE(review): presumably ordered (conv state, ssm state) -- confirm
        against ``MambaStateDtypeCalculator``.
    """
    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    return MambaStateDtypeCalculator.mamba2_state_dtype(
        model_config.dtype,
        cache_config.mamba_cache_dtype,
        cache_config.mamba_ssm_cache_dtype,
    )
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
@@ -966,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
self.lm_head.weight.dtype,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape)
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
# Get cache parameters for current run
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user