[Bugfix] clamp dA_cumsum differences to prevent Inf in Mamba2 SSD kernels (#37501)

Signed-off-by: Jingu Kang <jg.k@navercorp.com>
This commit is contained in:
Jingu Kang
2026-04-01 00:35:51 +09:00
committed by GitHub
parent 757068dc65
commit f1ff50c86c
2 changed files with 2 additions and 2 deletions

View File

@@ -356,7 +356,7 @@ def _chunk_scan_fwd_kernel(
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :])
cb *= fast_exp(tl.minimum(dA_cs_m[:, None] - dA_cs_k[None, :], 0.0))
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:

View File

@@ -280,7 +280,7 @@ def _chunk_state_fwd_kernel(
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k
scale = fast_exp(tl.minimum(dA_cs_last - dA_cs_k, 0.0)) * dt_k
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)