[Kernel] Mamba support different layout for Conv state (#37416)
This commit is contained in:
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
is_conv_state_dim_first,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
# conv_state must be (..., dim, width-1) for the conv kernels.
|
||||
# DS layout stores it that way directly; SD layout needs a
|
||||
# transpose (which keeps dim contiguous via stride tricks).
|
||||
conv_state = (
|
||||
self.kv_cache[0]
|
||||
if is_conv_state_dim_first()
|
||||
else self.kv_cache[0].transpose(-1, -2)
|
||||
)
|
||||
ssm_state = self.kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
|
||||
Reference in New Issue
Block a user