Add support for the Qwen3-Next model (a hybrid attention model). (#24526)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -14,7 +14,7 @@ import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .op import exp, safe_exp
|
||||
from .op import exp
|
||||
from .utils import is_nvidia_hopper, use_cuda_graph
|
||||
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
||||
@@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
m_t = (i_t * BT + tl.arange(0, BT)) < T
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
||||
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
|
||||
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
|
||||
b_g_last = exp(b_g_last)
|
||||
b_h1 = b_h1 * b_g_last
|
||||
if K > 64:
|
||||
|
||||
@@ -16,7 +16,7 @@ import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp, safe_exp
|
||||
from .op import exp
|
||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
@@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
|
||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_o = b_o * exp(b_g)[:, None]
|
||||
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
|
||||
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_A = o_i[:, None] >= o_i[None, :]
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
m_t = o_t < T
|
||||
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
|
||||
@@ -14,7 +14,7 @@ import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import safe_exp
|
||||
from .op import exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
@@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
o_t = tl.arange(0, BT)
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
m_t = o_t < T
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
@@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * safe_exp(b_g_diff)
|
||||
b_A = b_A * exp(b_g_diff)
|
||||
|
||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
@@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
|
||||
@@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
||||
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
|
||||
xmask = row_idx < M
|
||||
rindex = tl.arange(0, N)[None, :]
|
||||
xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
|
||||
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
|
||||
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
|
||||
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
|
||||
rsqrt = tl.rsqrt(square_sum + eps)
|
||||
|
||||
@@ -28,11 +28,6 @@ else:
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def safe_exp(x):
|
||||
return exp(tl.where(x <= 0, x, float('-inf')))
|
||||
|
||||
|
||||
if not hasattr(tl, 'gather'):
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user