[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
-                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
-                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
-                         PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         LoRAConfig, MambaDType, ModelConfig, ModelDType,
+                         ModelImpl, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+                         RunnerOption, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
|
||||
@@ -422,6 +422,8 @@ class EngineArgs:
     override_attention_dtype: str = ModelConfig.override_attention_dtype
 
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
+    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
 
     additional_config: dict[str, Any] = \
         get_field(VllmConfig, "additional_config")
@@ -694,6 +696,10 @@ class EngineArgs:
                                  **cache_kwargs["calculate_kv_scales"])
         cache_group.add_argument("--kv-sharing-fast-prefill",
                                  **cache_kwargs["kv_sharing_fast_prefill"])
+        cache_group.add_argument("--mamba-cache-dtype",
+                                 **cache_kwargs["mamba_cache_dtype"])
+        cache_group.add_argument("--mamba-ssm-cache-dtype",
+                                 **cache_kwargs["mamba_ssm_cache_dtype"])
 
         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1105,6 +1111,8 @@ class EngineArgs:
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
             kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
+            mamba_cache_dtype=self.mamba_cache_dtype,
+            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
         )
 
         ray_runtime_env = None
Reference in New Issue
Block a user