[Kernel] Mamba support different layout for Conv state (#37416)

This commit is contained in:
Nicolò Lucchesi
2026-04-03 01:50:09 +02:00
committed by GitHub
parent bb39382b2b
commit 66e86f1dbd
11 changed files with 169 additions and 39 deletions

View File

@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens

View File

@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1]
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d