[Bugfix] Mamba2 SSD varlen bug fix initstates decay, improve test, assert chunk pwr 2 (#21783)
Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
This commit is contained in:
@@ -290,10 +290,8 @@ def _chunk_scan_fwd_kernel(
|
||||
# get the cs at the offset boundary
|
||||
# - c_off == 0 is a passthrough
|
||||
dA_cs_m_boundary = tl.load(
|
||||
dA_cumsum_ptr +
|
||||
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
|
||||
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
|
||||
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
if HAS_SEQ_IDX:
|
||||
|
||||
@@ -21,6 +21,10 @@ from .ssd_state_passing import _state_passing_fwd
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||
|
||||
|
||||
def is_int_pow_2(n):
|
||||
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(x,
|
||||
dt,
|
||||
A,
|
||||
@@ -38,6 +42,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=None):
|
||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
|
||||
Reference in New Issue
Block a user