diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 0f587558b..01d395b1e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4 ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto" +def _set_conv_state_layout(monkeypatch, layout: str) -> None: + """Set conv state layout env var and clear cache to pick up new value.""" + from vllm.model_executor.layers.mamba import mamba_utils + + monkeypatch.setenv("VLLM_SSM_CONV_STATE_LAYOUT", layout) + mamba_utils.get_conv_state_layout.cache_clear() + + @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -102,12 +110,15 @@ def test_models( @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"]) def test_batching( vllm_runner, example_prompts, + monkeypatch, model: str, max_tokens: int, num_logprobs: int, + conv_state_layout: str, ) -> None: try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) @@ -116,6 +127,8 @@ def test_batching( except ValueError: pass + _set_conv_state_layout(monkeypatch, conv_state_layout) + for_loop_outputs = [] with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: @@ -138,11 +151,14 @@ def test_batching( @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"]) def test_chunked_prefill_with_parallel_sampling( vllm_runner, example_prompts, + monkeypatch, model: str, max_tokens: int, + conv_state_layout: str, ) -> None: """ Tests chunked prefill in conjunction with n > 1. @@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling( decoding steps inside a chunked prefill forward pass (where we have both prefill and decode together) """ + _set_conv_state_layout(monkeypatch, conv_state_layout) + sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens) with vllm_runner( model, @@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling( @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) +@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"]) def test_mamba_cache_cg_padding( vllm_runner, example_prompts, + monkeypatch, model: str, max_tokens: int, + conv_state_layout: str, ) -> None: """ This test is for verifying that mamba cache is padded to CG captured batch size. If it's not, a torch RuntimeError will be raised because tensor dimensions aren't compatible. 
""" + _set_conv_state_layout(monkeypatch, conv_state_layout) + vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() cudagraph_dispatcher = CudagraphDispatcher(vllm_config) cudagraph_dispatcher.initialize_cudagraph_keys( diff --git a/vllm/envs.py b/vllm/envs.py index 0a40030cf..c2f8ca8c5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -191,6 +191,7 @@ if TYPE_CHECKING: VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None + VLLM_SSM_CONV_STATE_LAYOUT: Literal["SD", "DS"] | None = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[ @@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_KV_CACHE_LAYOUT": env_with_choices( "VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"] ), + # SSM conv state layout used for Mamba models. + # - SD: (state_len, dim) — dim contiguous (default) + # - DS: (dim, state_len) — TP-sharded dim on dim1, + # consistent with SSM temporal state and HND KV cache layout. + "VLLM_SSM_CONV_STATE_LAYOUT": env_with_choices( + "VLLM_SSM_CONV_STATE_LAYOUT", None, ["SD", "DS"] + ), # Enable checking whether the generated logits contain NaNs, # indicating corrupted output. Useful for debugging low level bugs # or bad hardware but it may add compute overhead. diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 46db5dc32..b09f980c7 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -31,7 +31,11 @@ from .linear import ( RowParallelLinear, ) from .mamba.abstract import MambaBase -from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator +from .mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, + is_conv_state_dim_first, +) from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .quantization.base_config import QuantizationConfig @@ -315,10 +319,12 @@ class KimiDeltaAttention(nn.Module, MambaBase): beta = beta[:num_actual_tokens] (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches - # deal with strides - conv_state_q = conv_state_q.transpose(-1, -2) - conv_state_k = conv_state_k.transpose(-1, -2) - conv_state_v = conv_state_v.transpose(-1, -2) + # conv_state must be (..., dim, width-1) for the conv kernels. + # DS layout stores it that way directly; SD layout needs a transpose. 
+ if not is_conv_state_dim_first(): + conv_state_q = conv_state_q.transpose(-1, -2) + conv_state_k = conv_state_k.transpose(-1, -2) + conv_state_v = conv_state_v.transpose(-1, -2) q_conv_weights = self.q_conv1d.weight.view( self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index c5ea14cab..9b95e00d2 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, + is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, @@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) + # conv_state must be (..., dim, width-1) for the conv kernels. + # DS layout stores it that way directly; SD layout needs a transpose. + conv_state = ( + self_kv_cache[0] + if is_conv_state_dim_first() + else self_kv_cache[0].transpose(-1, -2) + ) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens @@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): """ non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) + # conv_state must be (..., dim, width-1) for the conv kernels. + # DS layout stores it that way directly; SD layout needs a transpose. 
+ conv_state = ( + self_kv_cache[0] + if is_conv_state_dim_first() + else self_kv_cache[0].transpose(-1, -2) + ) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index d79af2e27..fd83d4b83 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, + is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, @@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer): query_start_loc_p = attn_metadata.query_start_loc_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_d = attn_metadata.state_indices_tensor_d - self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] + conv_state = ( + self.kv_cache[0] + if is_conv_state_dim_first() + else self.kv_cache[0].transpose(-1, -2) + ) + ssm_state = self.kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 041405b05..01ea3fdca 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, + is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, @@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache - # conv_state = (..., dim, width-1) yet contiguous along 'dim' - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] + # conv_state must be (..., dim, width-1) for the conv kernels. + # DS layout stores it that way directly; SD layout needs a + # transpose (which keeps dim contiguous via stride tricks). 
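+        # E.g. an SD block of shape (state_len, dim) with strides (dim, 1)
+        # becomes a (dim, state_len) view with strides (1, dim); no copy.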
+ conv_state = ( + self.kv_cache[0] + if is_conv_state_dim_first() + else self.kv_cache[0].transpose(-1, -2) + ) + ssm_state = self.kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 1f6751f6c..a5a30502b 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,20 +1,52 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools from collections.abc import Callable from dataclasses import dataclass -from typing import TypeAlias +from typing import Literal, TypeAlias import torch +import vllm.envs as envs from vllm.config.cache import MambaDType from vllm.config.model import ModelDType from vllm.distributed import divide +from vllm.logger import init_logger from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype, ) +logger = init_logger(__name__) + +ConvStateLayoutType = Literal["SD", "DS"] + + +@functools.lru_cache +def get_conv_state_layout() -> ConvStateLayoutType: + """Return the SSM conv state layout. + + SD = (state_len, dim) — dim is the innermost contiguous dimension. + DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV + cache), consistent with SSM temporal state layout. + """ + layout: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT + if layout is not None: + logger.info_once( + "VLLM_SSM_CONV_STATE_LAYOUT env detected. " + "Setting SSM conv state layout to %s.", + layout, + ) + return layout + + return "SD" + + +def is_conv_state_dim_first() -> bool: + """True when the conv state is stored as (dim, state_len) per block.""" + return get_conv_state_layout() == "DS" + class MambaStateDtypeCalculator: @classmethod @@ -107,6 +139,13 @@ class MambaStateShapeCalculator: state_shape = (num_heads // tp_size, head_dim, head_dim) return (state_shape,) + @staticmethod + def _orient_conv_shape(dim: int, state_len: int) -> tuple[int, int]: + """Return (dim, state_len) for DS layout, (state_len, dim) for SD.""" + if is_conv_state_dim_first(): + return (dim, state_len) + return (state_len, dim) + @classmethod def mamba1_state_shape( cls, @@ -115,12 +154,11 @@ class MambaStateShapeCalculator: state_size: int, conv_kernel: int, ) -> tuple[tuple[int, int], tuple[int, int]]: - conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) + conv_dim = divide(intermediate_size, tp_world_size) + conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1) temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) - conv_state_shape = conv_state_shape[1], conv_state_shape[0] - return conv_state_shape, temporal_state_shape @classmethod @@ -141,8 +179,9 @@ class MambaStateShapeCalculator: # heads and n_groups are TP-ed conv_dim = intermediate_size + 2 * n_groups * state_size - # contiguous along 'dim' axis - conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp_world_size)) + conv_state_shape = cls._orient_conv_shape( + divide(conv_dim, tp_world_size), conv_kernel - 1 + num_spec + ) # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small @@ -158,7 +197,7 @@ class MambaStateShapeCalculator: conv_kernel: int, ) -> tuple[tuple[int, int]]: conv_dim = divide(intermediate_size, tp_world_size) - conv_state_shape = 
(conv_kernel - 1, conv_dim) + conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1) return (conv_state_shape,) @classmethod @@ -185,13 +224,11 @@ class MambaStateShapeCalculator: num_spec: int = 0, ): conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads - conv_state_shape = ( + conv_state_shape = cls._orient_conv_shape( divide(conv_dim, tp_world_size), conv_kernel_size - 1 + num_spec, ) - conv_state_shape = conv_state_shape[1], conv_state_shape[0] - temporal_state_shape = ( divide(num_v_heads, tp_world_size), head_v_dim, @@ -218,12 +255,13 @@ class MambaStateShapeCalculator: proj_size = num_heads * head_dim proj_k_size = num_k_heads * head_k_dim - conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1) - conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1) + conv_state_shape = cls._orient_conv_shape( + divide(proj_size, tp_world_size), conv_kernel_size - 1 + ) + conv_state_k_shape = cls._orient_conv_shape( + divide(proj_k_size, tp_world_size), conv_kernel_size - 1 + ) recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim) - - conv_state_shape = conv_state_shape[1], conv_state_shape[0] - conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0] return ( conv_state_shape, conv_state_k_shape, @@ -267,9 +305,27 @@ def get_conv_copy_spec( cur_block_idx: int, num_accepted_tokens: int, ) -> MambaCopySpec: - """Return a MambaCopySpec for copying a convolutional state slice.""" + """Return a MambaCopySpec for copying a convolutional state slice. + + Works for both SD layout ``(num_blocks, state_len, dim)`` and + DS layout ``(num_blocks, dim, state_len)``. + """ src_block_id = block_ids[cur_block_idx] - src_state = state[src_block_id, num_accepted_tokens - 1 :] + offset = num_accepted_tokens - 1 + if is_conv_state_dim_first(): + # DS layout: (num_blocks, dim, state_len) — state_len is last. + if offset > 0: + # Slicing along the last dim yields a non-contiguous view + # because features (dim) are strided by state_len. + raise NotImplementedError( + "DS conv state layout does not yet support speculative " + "decoding with mamba_cache_mode='align' " + "(num_accepted_tokens > 1)." + ) + src_state = state[src_block_id] + else: + # SD layout: (num_blocks, state_len, dim) — dim contiguous. 
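+        # Slicing rows off the front keeps the view contiguous, so
+        # (start_addr, num_elements) below describes one flat region.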
+ src_state = state[src_block_id, offset:] return MambaCopySpec( start_addr=src_state.data_ptr(), num_elements=src_state.numel() ) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index a8efdc9f1..1160105ad 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -592,7 +592,6 @@ def causal_conv1d_fn( stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) - assert stride_istate_dim == 1 if out.dim() == 2: stride_o_dim = out.stride(0) stride_o_token = out.stride(1) @@ -1149,9 +1148,6 @@ def causal_conv1d_update( if validate_data: assert dim == weight.size(0) - assert conv_state.stride(-2) == 1, ( - f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - ) assert state_len >= width - 1 # when above happens, we don't shift-left to keep any records in conv_state assert dim == conv_state.size(1) diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index d36dc0096..11e9b590f 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, + is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, @@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, ShortConvAttentionMetadata) - self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) + conv_state = ( + self.kv_cache[0] + if is_conv_state_dim_first() + else self.kv_cache[0].transpose(-1, -2) + ) state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_d = attn_metadata.state_indices_tensor_d has_initial_states_p = attn_metadata.has_initial_states_p diff --git a/vllm/model_executor/models/olmo_hybrid.py b/vllm/model_executor/models/olmo_hybrid.py index 97e56b3ff..d070132fc 100644 --- a/vllm/model_executor/models/olmo_hybrid.py +++ b/vllm/model_executor/models/olmo_hybrid.py @@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, + is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, @@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase): spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor self_kv_cache = self.kv_cache - conv_state = self_kv_cache[0].transpose(-1, -2) + # conv_state must be (..., dim, width-1) for the conv kernels. + # DS layout stores it that way directly; SD layout needs a transpose. 
+        conv_state = (
+            self_kv_cache[0]
+            if is_conv_state_dim_first()
+            else self_kv_cache[0].transpose(-1, -2)
+        )
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 44b120774..ce7acc1cb 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
+    is_conv_state_dim_first,
 )
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn,
@@ -266,7 +267,12 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
         self_kv_cache = self.kv_cache
-        # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-        conv_state = self_kv_cache[0].transpose(-1, -2)
+        # conv_state must be (..., dim, width-1) for the conv kernels.
+        # DS layout stores it that way directly; SD layout needs a transpose.
+        conv_state = (
+            self_kv_cache[0]
+            if is_conv_state_dim_first()
+            else self_kv_cache[0].transpose(-1, -2)
+        )
         ssm_state = self_kv_cache[1]
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d
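
Illustrative sketch (not part of the patch) of how the two layouts relate for
a single cache block; toy sizes, torch only:

    import torch

    state_len, dim = 3, 8             # width - 1 and the TP-sharded channels
    sd = torch.zeros(state_len, dim)  # SD: dim contiguous, strides (8, 1)
    ds = torch.zeros(dim, state_len)  # DS: state_len contiguous, strides (3, 1)

    # The conv kernels always consume (dim, state_len); SD gets there via a
    # zero-copy transpose whose dim stride stays 1:
    assert sd.transpose(-1, -2).shape == ds.shape
    assert sd.transpose(-1, -2).stride() == (1, dim)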