[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFunc,
|
||||
MambaStateCopyFuncCalculator,
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
@@ -261,6 +263,10 @@ class MambaForCausalLM(
|
||||
conv_kernel=hf_config.conv_kernel,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
|
||||
return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user