[Kernel] Mamba support different layout for Conv state (#37416)
This commit is contained in:
@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFuncCalculator,
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
is_conv_state_dim_first,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
|
||||
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
|
||||
self_kv_cache = self.kv_cache
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
# conv_state must be (..., dim, width-1) for the conv kernels.
|
||||
# DS (dim-first) layout already stores it as (..., dim, width-1); SD layout stores (..., width-1, dim) and needs a transpose.
|
||||
conv_state = (
|
||||
self_kv_cache[0]
|
||||
if is_conv_state_dim_first()
|
||||
else self_kv_cache[0].transpose(-1, -2)
|
||||
)
|
||||
ssm_state = self_kv_cache[1]
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_accepted_tokens = attn_metadata.num_accepted_tokens
|
||||
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFuncCalculator,
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
is_conv_state_dim_first,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
# conv_state must be (..., dim, width-1) for the conv kernels.
|
||||
# DS (dim-first) layout already stores it as (..., dim, width-1); SD layout stores (..., width-1, dim) and needs a transpose.
|
||||
conv_state = (
|
||||
self_kv_cache[0]
|
||||
if is_conv_state_dim_first()
|
||||
else self_kv_cache[0].transpose(-1, -2)
|
||||
)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
|
||||
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
|
||||
|
||||
Reference in New Issue
Block a user