[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
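In practice the float32 state is selected through the existing `mamba_cache_dtype` cache option rather than a new flag. A minimal offline-inference sketch, assuming the `LLM` entry point of this vLLM version forwards `mamba_cache_dtype` through `EngineArgs` to `CacheConfig` (the model name is illustrative):

    from vllm import LLM, SamplingParams

    # Keep the recurrent state in float32 while the model itself runs
    # in its native dtype; "auto" (the default) matches the state
    # dtype to the model dtype instead.
    llm = LLM(
        model="MiniMaxAI/MiniMax-Text-01",  # illustrative hybrid model
        mamba_cache_dtype="float32",
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=8))
    print(outputs[0].outputs[0].text)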
@@ -16,7 +16,8 @@ from transformers import MiniMaxConfig
 from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
+                         get_current_vllm_config)
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
@@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"
 
+    def get_state_dtype(self) -> tuple[torch.dtype]:
+        return MambaStateDtypeCalculator.linear_attention_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.linear_attention_state_shape(
             num_heads=self.num_heads,
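`get_state_dtype` delegates the dtype decision to `MambaStateDtypeCalculator`, whose body is not shown in this diff. A hypothetical re-implementation of the resolution rule the call sites suggest ("float32" forces a full-precision state, anything else falls back to the model's compute dtype); this is a sketch, not vLLM's actual helper:

    import torch

    # Hypothetical sketch of the dtype resolution for illustration only.
    def linear_attention_state_dtype(
            model_dtype: torch.dtype,
            mamba_cache_dtype: str) -> tuple[torch.dtype]:
        # "float32" pins the recurrent state to full precision for more
        # stable accumulation; "auto" mirrors the model's compute dtype.
        if mamba_cache_dtype == "float32":
            return (torch.float32,)
        return (model_dtype,)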
@@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         max_position: int,
         block_size: int,
         num_hidden_layer: int,
+        model_config: Optional[ModelConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_idx: int = 0,
         linear_layer_idx: int = 0,
@@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         self.tp_heads = self.total_num_heads // self.tp_size
         self.qkv_size = self.num_heads * self.head_dim
         self.tp_hidden = self.head_dim * self.tp_heads
+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix
 
         self.qkv_proj = ColumnParallelLinear(
@@ -657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
     def __init__(
         self,
         config: MiniMaxConfig,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         expert_num: int = 1,
@@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
             max_position=max_position_embeddings,
             block_size=config.block if hasattr(config, "block") else 256,
             num_hidden_layer=config.num_hidden_layers,
+            model_config=model_config,
+            cache_config=cache_config,
             quant_config=quant_config,
             layer_idx=self._ilayer,
             linear_layer_idx=linear_layer_id,
@@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module):
     def __init__(
         self,
         config: MiniMaxConfig,
+        model_config: Optional[ModelConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         scheduler_config=None,
@@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module):
             decoder_kwargs = {
                 "quant_config": quant_config,
                 "layer_id": layer_idx,
+                "model_config": model_config,
                 "cache_config": cache_config
             }
 
@@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         self.config.max_model_len = vllm_config.model_config.max_model_len
         self.model = MiniMaxText01Model(
             self.config,
-            quant_config,
+            model_config=vllm_config.model_config,
             cache_config=vllm_config.cache_config,
+            quant_config=quant_config,
             scheduler_config=vllm_config.scheduler_config,
             prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:
@@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
                 load_basic_weight(name, loaded_weight, self)
         return loaded_params
 
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.linear_attention_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,
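The new classmethod exposes the same dtype decision at the class level, alongside the existing `get_mamba_state_shape_from_config`, so the cache manager can plan state buffers from the config alone, before any layer object exists. A hedged sketch of a caller (the planner function and the shape/dtype pairing are illustrative, not vLLM's actual API):

    import torch

    # Hypothetical buffer planner: pair each state tensor's shape with
    # the dtype the model class reports for it, before instantiation.
    def plan_mamba_state_buffers(model_cls, vllm_config):
        dtypes = model_cls.get_mamba_state_dtype_from_config(vllm_config)
        shapes = model_cls.get_mamba_state_shape_from_config(vllm_config)
        return [torch.empty(shape, dtype=dtype)
                for shape, dtype in zip(shapes, dtypes)]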