diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 8057a8d32..e5e73625f 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -356,7 +356,7 @@ def _chunk_scan_fwd_kernel( ) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. - cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + cb *= fast_exp(tl.minimum(dA_cs_m[:, None] - dA_cs_k[None, :], 0.0)) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 37532e6db..8402d5291 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -280,7 +280,7 @@ def _chunk_state_fwd_kernel( dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( tl.float32 ) - scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k + scale = fast_exp(tl.minimum(dA_cs_last - dA_cs_k, 0.0)) * dt_k b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b)