Add the support for the qwen3 next model (a hybrid attention model). (#24526)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Tao He
2025-09-11 15:32:09 +08:00
committed by GitHub
parent 2048c4e379
commit e93f4cc9e3
29 changed files with 2476 additions and 61 deletions

View File

@@ -14,7 +14,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp, safe_exp
from .op import exp
from .utils import is_nvidia_hopper, use_cuda_graph
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
boundary_check=(0, 1))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
if K > 64:

View File

@@ -16,7 +16,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp, safe_exp
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
@@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),

View File

@@ -14,7 +14,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import safe_exp
from .op import exp
@triton.heuristics({
@@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_t = tl.arange(0, BT)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
@@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * safe_exp(b_g_diff)
b_A = b_A * exp(b_g_diff)
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

View File

@@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)

View File

@@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
rindex = tl.arange(0, N)[None, :]
xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)

View File

@@ -28,11 +28,6 @@ else:
log2 = tl.log2
@triton.jit
def safe_exp(x):
return exp(tl.where(x <= 0, x, float('-inf')))
if not hasattr(tl, 'gather'):
@triton.jit