[Kernel] Mamba support different layout for Conv state (#37416)

This commit is contained in:
Nicolò Lucchesi
2026-04-03 01:50:09 +02:00
committed by GitHub
parent bb39382b2b
commit 66e86f1dbd
11 changed files with 169 additions and 39 deletions

View File

@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) but contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
# conv_state must be (..., dim, width-1) for the conv kernels.
# The DS layout already stores it in that order; the SD layout
# needs a transpose, which returns a strided view (no copy), so
# 'dim' remains the contiguous axis either way.
conv_state = (
self.kv_cache[0]
if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
ssm_state = self.kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size