[Bugfix] clamp dA_cumsum differences to prevent Inf in Mamba2 SSD kernels (#37501)
Signed-off-by: Jingu Kang <jg.k@navercorp.com>
This commit is contained in:
@@ -356,7 +356,7 @@ def _chunk_scan_fwd_kernel(
|
||||
)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
- cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
+ cb *= fast_exp(tl.minimum(dA_cs_m[:, None] - dA_cs_k[None, :], 0.0))
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
|
||||
@@ -280,7 +280,7 @@ def _chunk_state_fwd_kernel(
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
- scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
+ scale = fast_exp(tl.minimum(dA_cs_last - dA_cs_k, 0.0)) * dt_k
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
|
||||
Reference in New Issue
Block a user