[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -455,6 +457,10 @@ class BambaForCausalLM(
             conv_kernel=hf_config.mamba_d_conv,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
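The pattern above repeats across every hybrid model touched by this change: the model class gains a `get_mamba_state_copy_func` classmethod that delegates to `MambaStateCopyFuncCalculator`, returning one copy function per Mamba state tensor. A hedged sketch of that delegation idea; the placeholder helpers below are illustrative, the real implementations live in `vllm/model_executor/layers/mamba/mamba_utils.py` and are not shown in this diff:

```python
# Sketch only: placeholder internals, real calculator not shown in this diff.
from collections.abc import Callable

MambaStateCopyFunc = Callable[..., object]  # assumed alias shape

def _copy_conv_state(*args: object) -> object:  # placeholder helper
    raise NotImplementedError

def _copy_ssm_state(*args: object) -> object:  # placeholder helper
    raise NotImplementedError

class MambaStateCopyFuncCalculator:
    @classmethod
    def mamba2_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        # Mamba2 layers keep two states (conv state, SSM state), hence
        # two copy functions, returned in state order.
        return (_copy_conv_state, _copy_ssm_state)
```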
@@ -330,26 +330,54 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         cache_config = vllm_config.cache_config

         if cache_config.enable_prefix_caching:
-            if model_config.supports_mamba_prefix_caching:
-                logger.info(
-                    "Warning: Prefix caching is currently enabled. "
-                    "Its support for Mamba layers is experimental. "
-                    "Please report any issues you may observe."
-                )
-                # By default, mamba block size will be set to max_model_len (see
-                # below). When enabling prefix caching, we align mamba block size
-                # to the block size as the basic granularity for prefix caching.
-                if cache_config.mamba_block_size is None:
-                    cache_config.mamba_block_size = cache_config.block_size
-            else:
-                logger.info(
-                    "Hybrid or mamba-based model detected without "
-                    "support for prefix caching: disabling."
-                )
-                cache_config.enable_prefix_caching = False
-
-        if cache_config.mamba_block_size is None:
-            cache_config.mamba_block_size = model_config.max_model_len
+            if cache_config.mamba_cache_mode == "none":
+                cache_config.mamba_cache_mode = (
+                    "all" if model_config.supports_mamba_prefix_caching else "align"
+                )
+                logger.warning(
+                    "Mamba cache mode is set to '%s' for %s by default "
+                    "when prefix caching is enabled",
+                    cache_config.mamba_cache_mode,
+                    model_config.architecture,
+                )
+            if (
+                cache_config.mamba_cache_mode == "all"
+                and not model_config.supports_mamba_prefix_caching
+            ):
+                cache_config.mamba_cache_mode = "align"
+                logger.warning(
+                    "Hybrid or mamba-based model detected without support "
+                    "for prefix caching with Mamba cache 'all' mode: "
+                    "falling back to 'align' mode."
+                )
+            if cache_config.mamba_cache_mode == "align":
+                assert vllm_config.scheduler_config.enable_chunked_prefill, (
+                    "Chunked prefill is required for mamba cache mode 'align'."
+                )
+                assert not vllm_config.speculative_config, (
+                    "Mamba cache mode 'align' is currently not compatible "
+                    "with speculative decoding."
+                )
+            logger.info(
+                "Warning: Prefix caching in Mamba cache '%s' "
+                "mode is currently enabled. "
+                "Its support for Mamba layers is experimental. "
+                "Please report any issues you may observe.",
+                cache_config.mamba_cache_mode,
+            )
+            # By default, mamba block size will be set to max_model_len (see
+            # below). When enabling prefix caching, we align mamba block size
+            # to the block size as the basic granularity for prefix caching.
+            if cache_config.mamba_block_size is None:
+                cache_config.mamba_block_size = cache_config.block_size
+        else:
+            if cache_config.mamba_cache_mode != "none":
+                cache_config.mamba_cache_mode = "none"
+                logger.warning(
+                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
+                )
+            if cache_config.mamba_block_size is None:
+                cache_config.mamba_block_size = model_config.max_model_len


 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
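The rewritten branch reduces to a small decision table: with prefix caching on, a requested mode of `none` picks a default (`all` when the model supports full Mamba prefix caching, else `align`), an unsupported `all` falls back to `align`, and `align` additionally requires chunked prefill and no speculative decoding; with prefix caching off, the mode is forced back to `none`. A minimal standalone sketch of that flow, with the config fields flattened into plain arguments for illustration:

```python
# Minimal sketch of the mode-resolution flow in the hunk above; argument
# names are assumed stand-ins for the vLLM config fields it references.
def resolve_mamba_cache_mode(
    enable_prefix_caching: bool,
    supports_mamba_prefix_caching: bool,
    requested: str = "none",
) -> str:
    if not enable_prefix_caching:
        return "none"  # caching disabled: mode is forced back to 'none'
    if requested == "none":
        # Pick a default: full-state caching when supported, else align.
        requested = "all" if supports_mamba_prefix_caching else "align"
    if requested == "all" and not supports_mamba_prefix_caching:
        requested = "align"  # an unsupported 'all' falls back to 'align'
    return requested

assert resolve_mamba_cache_mode(True, True) == "all"
assert resolve_mamba_cache_mode(True, False) == "align"
assert resolve_mamba_cache_mode(False, True, "all") == "none"
```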
@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         mamba_page_size = MambaSpec(
             shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
             dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
-            block_size=model_config.max_model_len,
+            block_size=-1,  # block_size doesn't matter for mamba page size
         ).page_size_bytes

         # Model may be marked as is_hybrid
@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if mamba_page_size == 0:
             return

-        if cache_config.enable_prefix_caching:
+        if cache_config.mamba_cache_mode == "all":
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance
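The `block_size=-1` change is safe because a Mamba page holds fixed-shape state tensors per sequence, so its byte size depends only on the state shapes and dtypes, never on the cache block size. A hedged sketch of that computation; `MambaSpec`'s real implementation may differ in detail:

```python
# Sketch of why block_size is irrelevant to a Mamba page's byte size:
# it is just shape-product times dtype width, summed over states.
import math
import torch

def mamba_page_size_bytes(shapes, dtypes) -> int:
    return sum(
        math.prod(shape) * dtype.itemsize
        for shape, dtype in zip(shapes, dtypes)
    )

# e.g. one layer's conv state and SSM state (illustrative shapes only):
print(mamba_page_size_bytes(
    [(8192, 4), (128, 64, 128)],
    [torch.bfloat16, torch.float32],
))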
@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
                 attn_block_size,
             )

+        # By default, mamba block size will be set to max_model_len.
+        # When enabling prefix caching and using align mamba cache
+        # mode, we align mamba block size to the block size as the
+        # basic granularity for prefix caching.
+        if cache_config.mamba_cache_mode == "align":
+            cache_config.mamba_block_size = cache_config.block_size
+
         # compute new attention page size
         attn_page_size = cache_config.block_size * attn_page_size_1_token
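In `align` mode the Mamba block size is pinned to the attention block size, so prefix-cache hits can be resolved at the same token granularity for both layer types, and the attention page size is then recomputed per block. A worked sketch of the arithmetic with illustrative numbers, not taken from any real model config:

```python
# Illustrative numbers only; names mirror the diff's variables.
block_size = 16                             # cache_config.block_size
attn_page_size_1_token = 2 * 32 * 128 * 2   # K+V, 32 heads, head_dim 128, bf16

mamba_block_size = block_size               # 'align' mode pins these together
attn_page_size = block_size * attn_page_size_1_token
print(mamba_block_size, attn_page_size)     # 16 262144
```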
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -551,6 +553,10 @@ class FalconH1ForCausalLM(
             conv_kernel=hf_config.mamba_d_conv,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM(
             conv_kernel=hf_config.mamba_d_conv,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig
 from vllm.inputs import TokensPrompt
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.utils.collection_utils import common_prefix
 from vllm.utils.func_utils import supports_kw
@@ -776,6 +777,19 @@ class IsHybrid(Protocol):
         """
         ...

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Calculate copy-function callables for each Mamba state.
+
+        Returns:
+            A tuple of MambaStateCopyFunc callables that correspond, in order,
+            to the Mamba states produced by the model. Each callable accepts
+            (state, block_ids, cur_block_idx, num_accepted_tokens) and returns
+            a MambaCopySpec describing the memory-copy parameters for prefix
+            caching in align mode.
+        """
+        ...
+

 @overload
 def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...
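The protocol fixes only the call signature, not what a conforming callable does. A hypothetical example matching the stated signature; `MambaCopySpec`'s fields are assumed for illustration and may differ from the real dataclass in vLLM's `mamba_utils`:

```python
# Hypothetical conforming callable; MambaCopySpec fields are assumed.
from dataclasses import dataclass
import torch

@dataclass
class MambaCopySpec:       # assumed structure, for illustration only
    src_block_idx: int
    dst_block_idx: int
    num_elements: int

def example_copy_func(
    state: torch.Tensor,        # [num_blocks, *state_shape]
    block_ids: list[int],       # cache blocks assigned to the request
    cur_block_idx: int,         # block whose state is being committed
    num_accepted_tokens: int,   # tokens confirmed in this step
) -> MambaCopySpec:
    # Commit the current block's state into its cache slot so a later
    # request sharing the prefix can resume from it.
    return MambaCopySpec(
        src_block_idx=cur_block_idx,
        dst_block_idx=block_ids[cur_block_idx],
        num_elements=state[0].numel(),
    )
```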
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -558,6 +560,10 @@ class JambaForCausalLM(
             conv_kernel=hf_config.mamba_d_conv,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -544,6 +546,14 @@ class KimiLinearForCausalLM(
             num_spec=num_spec,
         )

+    @classmethod
+    def get_mamba_state_copy_func(
+        cls,
+    ) -> tuple[
+        MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
+    ]:
+        return MambaStateCopyFuncCalculator.kda_state_copy_func()
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -459,14 +461,19 @@ class Lfm2ForCausalLM(
             conv_kernel=hf_config.conv_L_cache,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         cache_config = vllm_config.cache_config

-        assert not cache_config.enable_prefix_caching, (
-            "Lfm2 currently does not support prefix caching"
-        )
+        if cache_config.mamba_cache_mode == "all":
+            raise NotImplementedError(
+                "Lfm2 currently does not support 'all' prefix caching, "
+                "please use '--mamba-cache-mode=align' instead"
+            )

         super().__init__()
         self.config = config
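The blanket "does not support prefix caching" asserts in Lfm2 (and, below, Qwen3Next) become mode-specific errors that point users at `--mamba-cache-mode=align`, the flag named in the new messages. A hedged end-to-end usage sketch via the offline API; the Python keyword-argument spelling and the model name are assumptions for illustration:

```python
# Hedged usage sketch: CLI flag name comes from the error message above;
# the engine-arg spelling and model name are assumed placeholders.
from vllm import LLM

llm = LLM(
    model="LiquidAI/LFM2-1.2B",   # placeholder hybrid model
    enable_prefix_caching=True,
    mamba_cache_mode="align",     # assumed kwarg for --mamba-cache-mode
)
print(llm.generate(["Hello"])[0].outputs[0].text)
```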
@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM(
             conv_kernel=hf_config.conv_L_cache,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -261,6 +263,10 @@ class MambaForCausalLM(
             conv_kernel=hf_config.conv_kernel,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
+
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -228,6 +230,10 @@ class Mamba2ForCausalLM(
             conv_kernel=hf_config.conv_kernel,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -35,6 +35,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -1006,3 +1008,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             tp_size=parallel_config.tensor_parallel_size,
             head_dim=hf_config.head_dim,
         )
+
+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
@@ -2128,3 +2128,7 @@ class NemotronH_Nano_VL_V2(
         temp_vllm_config = copy.deepcopy(vllm_config)
         temp_vllm_config.model_config.hf_config = text_config
         return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)
+
+    @classmethod
+    def get_mamba_state_copy_func(cls):
+        return NemotronHForCausalLM.get_mamba_state_copy_func()
@@ -45,6 +45,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -809,6 +811,10 @@ class NemotronHForCausalLM(
             conv_kernel=hf_config.conv_kernel,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
@@ -27,6 +27,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -899,6 +901,10 @@ class Plamo2ForCausalLM(
             conv_kernel=hf_config.mamba_d_conv,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -48,6 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -1205,9 +1207,11 @@ class Qwen3NextForCausalLM(
         cache_config = vllm_config.cache_config

         scheduler_config = vllm_config.scheduler_config
-        assert not cache_config.enable_prefix_caching, (
-            "Qwen3Next currently does not support prefix caching"
-        )
+        if cache_config.mamba_cache_mode == "all":
+            raise NotImplementedError(
+                "Qwen3Next currently does not support 'all' prefix caching, "
+                "please use '--mamba-cache-mode=align' instead"
+            )
         self.quant_config = vllm_config.quant_config

         super().__init__()
@@ -1278,6 +1282,10 @@ class Qwen3NextForCausalLM(
             num_spec,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -234,9 +234,11 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
         cache_config = vllm_config.cache_config
-        assert not cache_config.enable_prefix_caching, (
-            "Qwen3NextMTP currently does not support prefix caching"
-        )
+        if cache_config.mamba_cache_mode == "all":
+            raise NotImplementedError(
+                "Qwen3NextMTP currently does not support 'all' prefix caching, "
+                "please use '--mamba-cache-mode=align' instead"
+            )

         self.quant_config = vllm_config.quant_config
@@ -32,6 +32,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateCopyFunc,
+    MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
@@ -891,6 +893,10 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
             conv_kernel=hf_config.mamba_d_conv,
         )

+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         """Initialize the Zamba2 model for causal language modeling.