From ccf02fcbaebb1a5b59dfc6c7cb64aa7cc489f04c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 14 Mar 2025 23:45:42 -0400 Subject: [PATCH] =?UTF-8?q?Revert=20"[Model]=20Mamba2=20Prefill=20Performa?= =?UTF-8?q?nce=20Tweaks:=20Fixing=20Flurry=20of=20U=E2=80=A6=20(#14848)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../layers/mamba/mamba_mixer2.py | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5b19e3f35..b53a540ed 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -466,17 +466,10 @@ class MambaMixer2(CustomOp): if has_prefill: initial_states = None - - if has_initial_states is not None and torch.any( - has_initial_states): - - # vectorized ssm_state zero init - batched_zero_init_func = torch.vmap( - lambda idx: mamba_cache_params.ssm_state[idx].zero_()) - batched_zero_init_func( - mamba_cache_params. - state_indices_tensor[~has_initial_states].unsqueeze( - dim=-1), ) + if has_initial_states is not None and any(has_initial_states): + for idx in mamba_cache_params.state_indices_tensor[ + ~has_initial_states]: + mamba_cache_params.ssm_state[idx].zero_() initial_states = mamba_cache_params.ssm_state[ mamba_cache_params.state_indices_tensor] @@ -500,17 +493,10 @@ class MambaMixer2(CustomOp): dt_limit=(0.0, float("inf")), ) - # vectorized ssm state update using vmap - # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap - # limitation which doesn't allow use of `item()` - # Note: the lambda capture can happen where ssm_state is initialized - # instead of here - batched_copy = torch.vmap( - lambda idx, source_state: mamba_cache_params.ssm_state[ - idx].copy_(source_state)) - batched_copy( - mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1), - varlen_state) + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + for i, idx in enumerate(mamba_cache_params.state_indices_tensor): + mamba_cache_params.ssm_state[idx].copy_(varlen_state[i]) # - reshape hidden_states = scan_output.view(seq_len, -1)