[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Thomas Parnell authored on 2025-08-15 14:57:06 +02:00 (committed by GitHub)
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

@@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
 DeviceConfig, DistributedExecutorBackend,
 GuidedDecodingBackend, HfOverrides, KVEventsConfig,
 KVTransferConfig, LoadConfig, LogprobsMode,
-LoRAConfig, ModelConfig, ModelDType, ModelImpl,
-MultiModalConfig, ObservabilityConfig, ParallelConfig,
-PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
-SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-get_field)
+LoRAConfig, MambaDType, ModelConfig, ModelDType,
+ModelImpl, MultiModalConfig, ObservabilityConfig,
+ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+RunnerOption, SchedulerConfig, SchedulerPolicy,
+SpeculativeConfig, TaskOption, TokenizerMode,
+VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
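
The import block above now pulls in MambaDType alongside the existing config types. A minimal sketch of the assumed alias (the real definition lives in vllm.config; treat this as illustrative, not authoritative):

    # Hedged sketch: assumed shape of the MambaDType alias imported above.
    from typing import Literal

    # "auto" defers to the surrounding cache dtype; "float32" is the option
    # this commit adds for the SSM state of hybrid models.
    MambaDType = Literal["auto", "float32"]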
@@ -422,6 +422,8 @@ class EngineArgs:
 override_attention_dtype: str = ModelConfig.override_attention_dtype
 calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
+mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
 additional_config: dict[str, Any] = \
     get_field(VllmConfig, "additional_config")
@@ -694,6 +696,10 @@ class EngineArgs:
 **cache_kwargs["calculate_kv_scales"])
 cache_group.add_argument("--kv-sharing-fast-prefill",
 **cache_kwargs["kv_sharing_fast_prefill"])
+cache_group.add_argument("--mamba-cache-dtype",
+**cache_kwargs["mamba_cache_dtype"])
+cache_group.add_argument("--mamba-ssm-cache-dtype",
+**cache_kwargs["mamba_ssm_cache_dtype"])

 # Multimodal related configs
 multimodal_kwargs = get_kwargs(MultiModalConfig)
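
Because the new options are ordinary argparse flags on the cache group, they can be set at launch time. A hedged example invocation (the model name is a placeholder, not from this commit):

    vllm serve <your-hybrid-model> --mamba-ssm-cache-dtype float32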
@@ -1105,6 +1111,8 @@ class EngineArgs:
 cpu_offload_gb=self.cpu_offload_gb,
 calculate_kv_scales=self.calculate_kv_scales,
 kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
+mamba_cache_dtype=self.mamba_cache_dtype,
+mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
 )

 ray_runtime_env = None
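
Taken together, the new fields can also be exercised through the Python API. A minimal sketch, assuming an illustrative model name; EngineArgs.create_engine_config() and VllmConfig.cache_config are existing vLLM APIs, and everything else follows the hunks above:

    # Hedged sketch: end-to-end use of the new EngineArgs fields.
    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="ibm-granite/granite-4.0-tiny-preview",  # illustrative hybrid model
        mamba_cache_dtype="auto",          # dtype for the Mamba caches overall
        mamba_ssm_cache_dtype="float32",   # new: keep only the SSM state in float32
    )
    vllm_config = args.create_engine_config()

    # Both values flow into CacheConfig, per the last hunk above.
    print(vllm_config.cache_config.mamba_cache_dtype)      # "auto"
    print(vllm_config.cache_config.mamba_ssm_cache_dtype)  # "float32"

If mamba_ssm_cache_dtype is left at "auto", the SSM state presumably falls back to whatever mamba_cache_dtype selects; only an explicit "float32" splits the two.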