diff --git a/tests/kernels/test_fused_sigmoid_gating_delta_rule.py b/tests/kernels/test_fused_sigmoid_gating_delta_rule.py new file mode 100644 index 000000000..2b03e83c3 --- /dev/null +++ b/tests/kernels/test_fused_sigmoid_gating_delta_rule.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.fla.ops import ( + fused_recurrent_gated_delta_rule, + fused_sigmoid_gating_delta_rule_update, +) +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + +DEVICE = current_platform.device_type + + +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("num_reqs", [1, 2, 4]) +@pytest.mark.parametrize("num_k_heads", [16]) +@pytest.mark.parametrize("num_v_heads", [32]) +@pytest.mark.parametrize("head_k_dim", [128]) +@pytest.mark.parametrize("head_v_dim", [128]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_fused_sigmoid_gating_delta_rule_update_non_spec( + tp_size: int, + num_reqs: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + dtype: torch.dtype, +) -> None: + torch.set_default_device(DEVICE) + set_random_seed(0) + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + mixed_qkv_dim = (key_dim * 2 + value_dim) // tp_size + seq_len = 1 # seq_len is 1 for decode + num_tokens = num_reqs * seq_len + total_entries = num_tokens * 2 + + mixed_qkv = torch.rand(num_tokens, mixed_qkv_dim, dtype=dtype) + query, key, value = torch.split( + mixed_qkv, + [ + key_dim // tp_size, + key_dim // tp_size, + value_dim // tp_size, + ], + dim=-1, + ) + query = query.view(1, num_tokens, num_k_heads, head_k_dim) + key = key.view(1, num_tokens, num_k_heads, head_k_dim) + value = value.view(1, num_tokens, num_v_heads, head_v_dim) + + A_log = torch.rand(num_v_heads // tp_size, dtype=dtype) + dt_bias = torch.rand(num_v_heads // tp_size, dtype=dtype) + a = torch.rand(num_tokens, num_v_heads, dtype=dtype) + b = torch.rand(num_tokens, num_v_heads, dtype=dtype) + ssm_state = torch.rand( + total_entries, num_v_heads, head_k_dim, head_v_dim, dtype=dtype + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32)[:num_tokens] + cu_seqlens = torch.arange(0, num_tokens + 1, dtype=torch.int32) + + beta = b.sigmoid() + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + core_attn_out_ref, last_recurrent_state_ref = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + g=g.unsqueeze(0), + beta=beta.unsqueeze(0), + initial_state=ssm_state.clone(), + inplace_final_state=True, + ssm_state_indices=state_indices, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out, last_recurrent_state = fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + a=a, + b=b, + dt_bias=dt_bias, + q=query, + k=key, + v=value, + initial_state=ssm_state, + inplace_final_state=True, + ssm_state_indices=state_indices, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + + torch.testing.assert_close(core_attn_out, core_attn_out_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close( + last_recurrent_state, last_recurrent_state_ref, atol=1e-2, rtol=1e-2 + ) + + +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("num_reqs", [1, 2, 4]) +@pytest.mark.parametrize("num_k_heads", [16]) +@pytest.mark.parametrize("num_v_heads", [32]) +@pytest.mark.parametrize("head_k_dim", [128]) 
+@pytest.mark.parametrize("head_v_dim", [128]) +@pytest.mark.parametrize("num_speculative_tokens", [1, 3]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_fused_sigmoid_gating_delta_rule_update_spec( + tp_size: int, + num_reqs: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + num_speculative_tokens: int, + dtype: torch.dtype, +) -> None: + torch.set_default_device(DEVICE) + set_random_seed(0) + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + mixed_qkv_dim = (key_dim * 2 + value_dim) // tp_size + num_tokens = num_reqs * (num_speculative_tokens + 1) + total_entries = num_tokens * 2 + + mixed_qkv = torch.rand(num_tokens, mixed_qkv_dim, dtype=dtype) + query, key, value = torch.split( + mixed_qkv, + [ + key_dim // tp_size, + key_dim // tp_size, + value_dim // tp_size, + ], + dim=-1, + ) + query = query.view(1, num_tokens, num_k_heads, head_k_dim) + key = key.view(1, num_tokens, num_k_heads, head_k_dim) + value = value.view(1, num_tokens, num_v_heads, head_v_dim) + + A_log = torch.rand(num_v_heads // tp_size, dtype=dtype) + dt_bias = torch.rand(num_v_heads // tp_size, dtype=dtype) + a = torch.rand(num_tokens, num_v_heads, dtype=dtype) + b = torch.rand(num_tokens, num_v_heads, dtype=dtype) + ssm_state = torch.rand( + total_entries, num_v_heads, head_k_dim, head_v_dim, dtype=dtype + ) + state_indices = torch.randperm( + total_entries, + dtype=torch.int32, + )[:num_tokens].view(num_reqs, num_speculative_tokens + 1) + num_accepted_tokens = torch.randint( + 1, num_speculative_tokens + 1, (num_reqs,), dtype=torch.int32 + ) + cu_seqlens = torch.arange( + 0, num_tokens + 1, num_speculative_tokens + 1, dtype=torch.int32 + ) + + beta = b.sigmoid() + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + core_attn_out_ref, last_recurrent_state_ref = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + g=g.unsqueeze(0), + beta=beta.unsqueeze(0), + initial_state=ssm_state.clone(), + inplace_final_state=True, + ssm_state_indices=state_indices, + cu_seqlens=cu_seqlens, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out, last_recurrent_state = fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + a=a, + b=b, + dt_bias=dt_bias, + q=query, + k=key, + v=value, + initial_state=ssm_state, + inplace_final_state=True, + ssm_state_indices=state_indices, + cu_seqlens=cu_seqlens, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + + torch.testing.assert_close(core_attn_out, core_attn_out_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close( + last_recurrent_state, last_recurrent_state_ref, atol=1e-2, rtol=1e-2 + ) diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py index c19cc14ba..06bd38d4c 100644 --- a/vllm/model_executor/layers/fla/ops/__init__.py +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -8,10 +8,12 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang from .chunk import chunk_gated_delta_rule from .fused_recurrent import fused_recurrent_gated_delta_rule +from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update from .layernorm_guard import RMSNormGated __all__ = [ "RMSNormGated", "chunk_gated_delta_rule", "fused_recurrent_gated_delta_rule", + "fused_sigmoid_gating_delta_rule_update", ] diff --git a/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py b/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py new file mode 100644 
index 000000000..414891fd8
--- /dev/null
+++ b/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
@@ -0,0 +1,279 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

+import torch
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.heuristics(
+    {
+        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
+        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
+    }
+)
+@triton.jit(do_not_specialize=["N", "T"])
+def fused_sigmoid_gating_delta_rule_update_kernel(
+    A_log,
+    a,
+    b,
+    dt_bias,
+    beta,
+    threshold,
+    q,
+    k,
+    v,
+    o,
+    h0,
+    ht,
+    cu_seqlens,
+    ssm_state_indices,
+    num_accepted_tokens,
+    scale,
+    N: tl.int64,  # num of sequences
+    T: tl.int64,  # num of tokens
+    B: tl.constexpr,
+    H: tl.constexpr,
+    HV: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    stride_init_state_token: tl.constexpr,
+    stride_final_state_token: tl.constexpr,
+    stride_indices_seq: tl.constexpr,
+    stride_indices_tok: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
+    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    IS_CONTINUOUS_BATCHING: tl.constexpr,
+    IS_SPEC_DECODING: tl.constexpr,
+    IS_KDA: tl.constexpr,
+):
+    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_n, i_hv = i_nh // HV, i_nh % HV
+    i_h = i_hv // (HV // H)
+    if IS_VARLEN:
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int64),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
+        )
+        all = T
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+        all = B * T
+
+    if T == 0:
+        # no tokens to process for this sequence
+        return
+
+    o_k = i_k * BK + tl.arange(0, BK)
+    o_v = i_v * BV + tl.arange(0, BV)
+
+    p_q = q + (bos * H + i_h) * K + o_k
+    p_k = k + (bos * H + i_h) * K + o_k
+    p_v = v + (bos * HV + i_hv) * V + o_v
+
+    p_A_log = A_log + i_hv
+    if not IS_KDA:
+        p_a = a + bos * HV + i_hv
+        p_dt_bias = dt_bias + i_hv
+    else:
+        p_a = a + (bos * HV + i_hv) * K + o_k
+        p_dt_bias = dt_bias + i_hv * K + o_k
+
+    p_b = b + bos * HV + i_hv
+    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+    mask_k = o_k < K
+    mask_v = o_v < V
+    mask_h = mask_v[:, None] & mask_k[None, :]
+
+    b_h = tl.zeros([BV, BK], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        if IS_CONTINUOUS_BATCHING:
+            if IS_SPEC_DECODING:
+                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
+            else:
+                i_t = 0
+            # Load state index and check for PAD_SLOT_ID (-1)
+            state_idx = tl.load(
+                ssm_state_indices
+                + i_n * stride_indices_seq
+                + i_t * stride_indices_tok
+            ).to(tl.int64)
+            # Skip if state index is invalid (PAD_SLOT_ID = -1)
+            if state_idx < 0:
+                return
+            p_h0 = h0 + state_idx * stride_init_state_token
+        else:
+            p_h0 = h0 + bos * HV * V * K
+        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+    for i_t in range(0, T):
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_b = tl.load(p_b).to(tl.float32)
+
+        # Upcast to fp32 before gating: if the model is loaded in fp16,
+        # A can otherwise overflow to -inf.
+        x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
+        softplus_x = tl.where(
+            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+        )
+        b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x
+
+        # compute the update gate b_beta = sigmoid(b)
+        b_beta = tl.sigmoid(b_b.to(tl.float32))
+
+        if USE_QK_L2NORM_IN_KERNEL:
+            b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
+        b_q = b_q * scale
+        # [BV, BK]
+        if not IS_KDA:
+            b_h *= tl.exp(b_g)
+        else:
+            b_h *= tl.exp(b_g[None, :])
+        # [BV]
+        b_v -= tl.sum(b_h * b_k[None, :], 1)
+        b_v *= b_beta
+        # [BV, BK]
+        b_h += b_v[:, None] * b_k[None, :]
+        # [BV]
+        b_o = tl.sum(b_h * b_q[None, :], 1)
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+        # keep the states for multi-query tokens
+        if INPLACE_FINAL_STATE:
+            # Load state index and check for PAD_SLOT_ID (-1)
+            final_state_idx = tl.load(
+                ssm_state_indices
+                + i_n * stride_indices_seq
+                + i_t * stride_indices_tok
+            ).to(tl.int64)
+            # Only store if state index is valid (not PAD_SLOT_ID)
+            if final_state_idx >= 0:
+                p_ht = ht + final_state_idx * stride_final_state_token
+                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+        else:
+            p_ht = ht + (bos + i_t) * stride_final_state_token
+            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+        # Update pointers for next timestep
+        p_q += H * K
+        p_k += H * K
+        p_o += HV * V
+        p_v += HV * V
+        p_b += HV
+        p_a += HV
+
+
+def fused_sigmoid_gating_delta_rule_update(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    dt_bias: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: float = 1.0,
+    threshold: float = 20.0,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
+    inplace_final_state: bool = True,
+    cu_seqlens: torch.LongTensor | None = None,
+    ssm_state_indices: torch.Tensor | None = None,
+    num_accepted_tokens: torch.Tensor | None = None,
+    use_qk_l2norm_in_kernel: bool = False,
+    is_kda: bool = False,
+):
+    """
+    Fused Triton implementation of the sigmoid gating delta rule update.
+    A single kernel combines the sigmoid gating computation and the
+    recurrent delta rule update, avoiding a separate gating pass.
+    """
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    HV = v.shape[2]
+    N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
+    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, "NK > 1 is not supported yet"
+    num_stages = 3
+    num_warps = 4
+
+    if cu_seqlens is not None and q.shape[0] != 1:
+        raise ValueError(
+            f"The batch size is expected to be 1 rather than {q.shape[0]}"
+            f" when using `cu_seqlens`. Please flatten variable-length"
+            f" inputs before processing."
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_sigmoid_gating_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a.contiguous(), + b=b.contiguous(), + dt_bias=dt_bias, + beta=beta, + threshold=threshold, + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + INPLACE_FINAL_STATE=inplace_final_state, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + IS_KDA=is_kda, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 7f1386d7b..9eba97c26 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.fla.ops import ( chunk_gated_delta_rule as fla_chunk_gated_delta_rule, ) from vllm.model_executor.layers.fla.ops import ( - fused_recurrent_gated_delta_rule, + fused_sigmoid_gating_delta_rule_update, ) from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd from vllm.model_executor.layers.fused_moe import SharedFusedMoE @@ -731,41 +731,40 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): mixed_qkv_non_spec ) - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) + if attn_metadata.num_prefills > 0: + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + if spec_sequence_masks is not None: g_non_spec = g.index_select(1, non_spec_token_indx) beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_non_spec = g + beta_non_spec = beta else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta + g_non_spec = None + beta_non_spec = None # 2. 
Recurrent attention # 2.1: Process the multi-query part if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=True, + core_attn_out_spec, last_recurrent_state = ( + fused_sigmoid_gating_delta_rule_update( + A_log=self.A_log, + a=a, + b=b, + dt_bias=self.dt_bias, + q=query_spec, + k=key_spec, + v=value_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[ + : attn_metadata.num_spec_decodes + 1 + ], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) ) else: core_attn_out_spec, last_recurrent_state = None, None @@ -794,12 +793,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( + fused_sigmoid_gating_delta_rule_update( + A_log=self.A_log, + a=a, + b=b, + dt_bias=self.dt_bias, q=query_non_spec, k=key_non_spec, v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=non_spec_query_start_loc[
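Reviewer note: the new kernel fuses two stages that were previously separate in `qwen3_next.py`: the sigmoid gating computed by `fused_gdn_gating` (`g = -A_log.exp() * softplus(a + dt_bias)`, `beta = sigmoid(b)`) and the recurrent delta-rule state update performed by `fused_recurrent_gated_delta_rule`. The sketch below is a minimal, unfused PyTorch reference for a single value head and a single sequence, written to mirror the kernel's inner loop and the tests' reference path. It is illustrative only; `gating_delta_rule_ref` and its argument shapes are not part of the vLLM API.

    import torch
    import torch.nn.functional as F

    def gating_delta_rule_ref(A_log, a, b, dt_bias, q, k, v, h, scale=None):
        # q, k: [T, K]; v: [T, V]; h: [V, K] recurrent state
        # A_log, dt_bias: scalar tensors; a, b: [T]
        # Mirrors use_qk_l2norm_in_kernel=True.
        if scale is None:
            scale = k.shape[-1] ** -0.5
        o = torch.empty_like(v, dtype=torch.float32)
        for t in range(v.shape[0]):
            q_t, k_t = q[t].float(), k[t].float()
            # L2-normalize q/k with the kernel's epsilon, then scale q
            q_t = q_t * torch.rsqrt((q_t * q_t).sum() + 1e-6) * scale
            k_t = k_t * torch.rsqrt((k_t * k_t).sum() + 1e-6)
            # sigmoid gating, now computed per token inside the kernel
            g_t = -A_log.float().exp() * F.softplus(a[t].float() + dt_bias.float())
            beta_t = b[t].float().sigmoid()
            h = h * g_t.exp()                          # decay the state
            delta = (v[t].float() - h @ k_t) * beta_t  # gated delta-rule error
            h = h + torch.outer(delta, k_t)            # rank-1 state update
            o[t] = h @ q_t                             # readout
        return o, h

The fused path therefore never materializes intermediate `g` and `beta` tensors for decode tokens, which is the point of replacing the `fused_gdn_gating` + `fused_recurrent_gated_delta_rule` pair in the decode and speculative-decode branches above.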