diff --git a/tests/kernels/test_fused_gdn_post_conv.py b/tests/kernels/test_fused_gdn_post_conv.py new file mode 100644 index 000000000..ffc8ce281 --- /dev/null +++ b/tests/kernels/test_fused_gdn_post_conv.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for fused_gdn_prefill_post_conv kernel. + +Verifies that the fused kernel matches the reference: + split → rearrange → contiguous → l2norm → gating +""" + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.fla.ops.fused_gdn_prefill_post_conv import ( + fused_post_conv_prep, +) + + +def reference_post_conv( + conv_output: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + H: int, + K: int, + V: int, + apply_l2norm: bool = True, + output_g_exp: bool = False, +): + """Reference implementation using individual ops.""" + L = conv_output.shape[0] + HV = A_log.shape[0] + + # Split + q_flat, k_flat, v_flat = torch.split(conv_output, [H * K, H * K, HV * V], dim=-1) + + # Rearrange + contiguous + q = q_flat.view(L, H, K).contiguous() + k = k_flat.view(L, H, K).contiguous() + v = v_flat.view(L, HV, V).contiguous() + + # L2 norm + if apply_l2norm: + q = F.normalize(q.float(), p=2, dim=-1, eps=1e-6).to(conv_output.dtype) + k = F.normalize(k.float(), p=2, dim=-1, eps=1e-6).to(conv_output.dtype) + + # Gating + x = a.float() + dt_bias.float() + sp = F.softplus(x, beta=1.0, threshold=20.0) + g = -torch.exp(A_log.float()) * sp + + if output_g_exp: + g = torch.exp(g) + + beta_out = torch.sigmoid(b.float()) + + return q, k, v, g, beta_out + + +# Qwen3.5-35B config: H=16, HV=32, K=128, V=128 +# Qwen3.5-397B config: H=16, HV=64, K=128, V=128 +@pytest.mark.parametrize( + "H, HV, K, V", + [ + (16, 32, 128, 128), # 35B + (16, 64, 128, 128), # 397B + (4, 8, 64, 64), # small + ], +) +@pytest.mark.parametrize("L", [1, 16, 128, 512, 2048]) 
+@pytest.mark.parametrize("apply_l2norm", [True, False]) +@pytest.mark.parametrize("output_g_exp", [True, False]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_post_conv_correctness(H, HV, K, V, L, apply_l2norm, output_g_exp, dtype): + """Test fused kernel matches reference for all configs.""" + torch.manual_seed(42) + device = "cuda" + qkv_dim = 2 * H * K + HV * V + + conv_output = torch.randn(L, qkv_dim, dtype=dtype, device=device) + a = torch.randn(L, HV, dtype=dtype, device=device) + b = torch.randn(L, HV, dtype=dtype, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) - 2.0 + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) * 0.1 + + # Reference + ref_q, ref_k, ref_v, ref_g, ref_beta = reference_post_conv( + conv_output, + a, + b, + A_log, + dt_bias, + H, + K, + V, + apply_l2norm, + output_g_exp, + ) + + # Fused kernel + fused_q, fused_k, fused_v, fused_g, fused_beta = fused_post_conv_prep( + conv_output, + a, + b, + A_log, + dt_bias, + num_k_heads=H, + head_k_dim=K, + head_v_dim=V, + apply_l2norm=apply_l2norm, + output_g_exp=output_g_exp, + ) + + # Check shapes + assert fused_q.shape == (L, H, K), f"q shape: {fused_q.shape}" + assert fused_k.shape == (L, H, K), f"k shape: {fused_k.shape}" + assert fused_v.shape == (L, HV, V), f"v shape: {fused_v.shape}" + assert fused_g.shape == (L, HV), f"g shape: {fused_g.shape}" + assert fused_beta.shape == (L, HV), f"beta shape: {fused_beta.shape}" + + # Check dtypes + assert fused_q.dtype == dtype + assert fused_k.dtype == dtype + assert fused_v.dtype == dtype + assert fused_g.dtype == torch.float32 + assert fused_beta.dtype == torch.float32 + + # Check contiguity + assert fused_q.is_contiguous() + assert fused_k.is_contiguous() + assert fused_v.is_contiguous() + + # Check values + atol_qkv = 1e-2 if apply_l2norm else 1e-3 + rtol_qkv = 1e-2 if apply_l2norm else 1e-3 + + torch.testing.assert_close(fused_q, ref_q, atol=atol_qkv, rtol=rtol_qkv) + 
torch.testing.assert_close(fused_k, ref_k, atol=atol_qkv, rtol=rtol_qkv) + torch.testing.assert_close(fused_v, ref_v, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(fused_g, ref_g, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(fused_beta, ref_beta, atol=1e-4, rtol=1e-4) + + +@pytest.mark.parametrize("L", [1, 64, 256]) +def test_fused_post_conv_sanity(L): + """Sanity checks: no NaN, unit-norm q/k, beta in (0,1).""" + torch.manual_seed(0) + device = "cuda" + H, HV, K, V = 16, 32, 128, 128 + qkv_dim = 2 * H * K + HV * V + + conv_output = torch.randn(L, qkv_dim, dtype=torch.bfloat16, device=device) + a = torch.randn(L, HV, dtype=torch.bfloat16, device=device) + b = torch.randn(L, HV, dtype=torch.bfloat16, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) - 2.0 + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + + q, k, v, g, beta = fused_post_conv_prep( + conv_output, + a, + b, + A_log, + dt_bias, + num_k_heads=H, + head_k_dim=K, + head_v_dim=V, + ) + + # Basic sanity + assert not torch.isnan(q).any(), "NaN in q" + assert not torch.isnan(k).any(), "NaN in k" + assert not torch.isnan(v).any(), "NaN in v" + assert not torch.isnan(g).any(), "NaN in g" + assert not torch.isnan(beta).any(), "NaN in beta" + + # L2 norm check: each head vector should have unit norm + q_norms = torch.norm(q.float(), dim=-1) + k_norms = torch.norm(k.float(), dim=-1) + torch.testing.assert_close(q_norms, torch.ones_like(q_norms), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(k_norms, torch.ones_like(k_norms), atol=1e-3, rtol=1e-3) + + # Beta should be in (0, 1) + assert (beta >= 0).all() and (beta <= 1).all(), "beta out of range" + + +def test_fused_post_conv_l0(): + """Test L=0 edge case.""" + device = "cuda" + H, HV, K, V = 16, 32, 128, 128 + qkv_dim = 2 * H * K + HV * V + + conv_output = torch.empty(0, qkv_dim, dtype=torch.bfloat16, device=device) + a = torch.empty(0, HV, dtype=torch.bfloat16, device=device) + b = torch.empty(0, HV, 
dtype=torch.bfloat16, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + + q, k, v, g, beta = fused_post_conv_prep( + conv_output, + a, + b, + A_log, + dt_bias, + num_k_heads=H, + head_k_dim=K, + head_v_dim=V, + ) + assert q.shape == (0, H, K) + assert g.shape == (0, HV) diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py index e52387a20..1942d8980 100644 --- a/vllm/model_executor/layers/fla/ops/__init__.py +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -7,6 +7,7 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang from .chunk import chunk_gated_delta_rule +from .fused_gdn_prefill_post_conv import fused_post_conv_prep from .fused_recurrent import ( fused_recurrent_gated_delta_rule, fused_recurrent_gated_delta_rule_packed_decode, @@ -19,5 +20,6 @@ __all__ = [ "chunk_gated_delta_rule", "fused_recurrent_gated_delta_rule", "fused_recurrent_gated_delta_rule_packed_decode", + "fused_post_conv_prep", "fused_sigmoid_gating_delta_rule_update", ] diff --git a/vllm/model_executor/layers/fla/ops/fused_gdn_prefill_post_conv.py b/vllm/model_executor/layers/fla/ops/fused_gdn_prefill_post_conv.py new file mode 100644 index 000000000..4807c78e7 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/fused_gdn_prefill_post_conv.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Fused post-conv1d preparation for GDN prefill. + +Replaces the chain: + split → rearrange → contiguous * 3 → l2norm * 2 → gating +with a **single Triton kernel** that reads the conv'd mixed_qkv output +and writes directly to q/k/v/g/beta in the target contiguous layout. 
+ +""" + +from __future__ import annotations + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fused_post_conv_kernel( + # ---- inputs ---- + mixed_qkv_ptr, # [L, qkv_dim] conv'd output (contiguous) + a_ptr, # [L, HV] + b_ptr, # [L, HV] + # ---- params ---- + A_log_ptr, # [HV] + dt_bias_ptr, # [HV] + # ---- outputs ---- + q_ptr, # [L, H, K] contiguous + k_ptr, # [L, H, K] contiguous + v_ptr, # [L, HV, V] contiguous + g_ptr, # [L, HV] float32 + beta_ptr, # [L, HV] float32 + # ---- strides ---- + stride_x_tok, # qkv_dim + stride_a_tok, # HV + stride_b_tok, # HV + stride_q_tok, # H * K + stride_k_tok, # H * K + stride_v_tok, # HV * V + # ---- dims ---- + L, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + APPLY_L2NORM: tl.constexpr, + L2NORM_EPS: tl.constexpr, + OUTPUT_G_EXP: tl.constexpr, + SOFTPLUS_THRESHOLD: tl.constexpr, + BLOCK_T: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + """Single fused kernel for post-conv1d preparation. 
+
+    Grid: (cdiv(L, BLOCK_T), H + HV)
+    - program_id(1) in [0, H): Q/K head processing + l2norm
+    - program_id(1) in [H, H+HV): V head processing + gating
+    """
+    i_tb = tl.program_id(0)
+    i_head = tl.program_id(1)
+
+    HK: tl.constexpr = H * K
+
+    offs_t = i_tb * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
+    mask_t = offs_t < L
+
+    if i_head < H:
+        # ============ Q/K head processing ============
+        i_h = i_head
+        offs_k = tl.arange(0, BK)  # [BK]
+        mask_k = offs_k < K
+        mask_2d = mask_t[:, None] & mask_k[None, :]  # [BLOCK_T, BK]
+
+        # Load Q features: mixed_qkv[t, i_h*K + k]
+        q_offsets = offs_t[:, None] * stride_x_tok + i_h * K + offs_k[None, :]
+        q_f32 = tl.load(mixed_qkv_ptr + q_offsets, mask=mask_2d, other=0).to(tl.float32)
+
+        # Load K features: mixed_qkv[t, HK + i_h*K + k]
+        k_offsets = offs_t[:, None] * stride_x_tok + HK + i_h * K + offs_k[None, :]
+        k_f32 = tl.load(mixed_qkv_ptr + k_offsets, mask=mask_2d, other=0).to(tl.float32)
+
+        if APPLY_L2NORM:
+            q_sq_sum = tl.sum(q_f32 * q_f32, axis=1)  # [BLOCK_T]
+            q_inv = 1.0 / tl.sqrt(q_sq_sum + L2NORM_EPS)
+            q_f32 = q_f32 * q_inv[:, None]
+
+            k_sq_sum = tl.sum(k_f32 * k_f32, axis=1)
+            k_inv = 1.0 / tl.sqrt(k_sq_sum + L2NORM_EPS)
+            k_f32 = k_f32 * k_inv[:, None]
+
+        # Store Q
+        q_out = offs_t[:, None] * stride_q_tok + i_h * K + offs_k[None, :]
+        tl.store(
+            q_ptr + q_out,
+            q_f32.to(q_ptr.dtype.element_ty),
+            mask=mask_2d,
+        )
+
+        # Store K
+        k_out = offs_t[:, None] * stride_k_tok + i_h * K + offs_k[None, :]
+        tl.store(
+            k_ptr + k_out,
+            k_f32.to(k_ptr.dtype.element_ty),
+            mask=mask_2d,
+        )
+    else:
+        # ============ V head + gating processing ============
+        i_hv = i_head - H
+        offs_v = tl.arange(0, BV)  # [BV]
+        mask_v = offs_v < V
+        mask_2d = mask_t[:, None] & mask_v[None, :]  # [BLOCK_T, BV]
+
+        V_OFFSET: tl.constexpr = 2 * H * K
+
+        # Load V features: mixed_qkv[t, 2*H*K + i_hv*V + v]
+        v_offsets = (
+            offs_t[:, None] * stride_x_tok + V_OFFSET + i_hv * V + offs_v[None, :]
+        )
+        v_vals = tl.load(mixed_qkv_ptr +
v_offsets, mask=mask_2d, other=0) + + # Store V + v_out = offs_t[:, None] * stride_v_tok + i_hv * V + offs_v[None, :] + tl.store(v_ptr + v_out, v_vals, mask=mask_2d) + + # Gating: one scalar per (token, v-head) + A_log_val = tl.load(A_log_ptr + i_hv).to(tl.float32) + dt_bias_val = tl.load(dt_bias_ptr + i_hv).to(tl.float32) + + a_offsets = offs_t * stride_a_tok + i_hv + b_offsets = offs_t * stride_b_tok + i_hv + a_vals = tl.load(a_ptr + a_offsets, mask=mask_t, other=0).to(tl.float32) + b_vals = tl.load(b_ptr + b_offsets, mask=mask_t, other=0).to(tl.float32) + + # g = -exp(A_log) * softplus(a + dt_bias) + x = a_vals + dt_bias_val + sp = tl.where(x > 0, x + tl.log(1.0 + tl.exp(-x)), tl.log(1.0 + tl.exp(x))) + sp = tl.where(x <= SOFTPLUS_THRESHOLD, sp, x) + g_vals = -tl.exp(A_log_val) * sp + + if OUTPUT_G_EXP: + g_vals = tl.exp(g_vals) + + beta_vals = tl.sigmoid(b_vals) + + gb_offsets = offs_t * HV + i_hv + tl.store(g_ptr + gb_offsets, g_vals, mask=mask_t) + tl.store(beta_ptr + gb_offsets, beta_vals, mask=mask_t) + + +def fused_post_conv_prep( + conv_output: torch.Tensor, # [L, qkv_dim] conv'd mixed_qkv + a: torch.Tensor, # [L, HV] + b: torch.Tensor, # [L, HV] + A_log: torch.Tensor, # [HV] + dt_bias: torch.Tensor, # [HV] + num_k_heads: int, + head_k_dim: int, + head_v_dim: int, + apply_l2norm: bool = True, + output_g_exp: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused post-conv1d prep: split + l2norm + gating in one kernel. 
+ + Args: + conv_output: [L, qkv_dim] contiguous conv'd mixed_qkv + a: [L, HV] gating input + b: [L, HV] gating input + A_log: [HV] log decay parameter + dt_bias: [HV] dt bias parameter + num_k_heads: number of K heads (H) + head_k_dim: dimension per K head (K) + head_v_dim: dimension per V head (V) + apply_l2norm: whether to L2-normalize q and k + output_g_exp: if True, output exp(g) instead of g (for FlashInfer) + + Returns: + q: [L, H, K] contiguous, optionally l2-normalized + k: [L, H, K] contiguous, optionally l2-normalized + v: [L, HV, V] contiguous + g: [L, HV] float32 + beta: [L, HV] float32 + """ + L = conv_output.shape[0] + qkv_dim = conv_output.shape[1] + H = num_k_heads + K = head_k_dim + V = head_v_dim + HV = A_log.shape[0] + dtype = conv_output.dtype + device = conv_output.device + + assert qkv_dim == 2 * H * K + HV * V, ( + f"qkv_dim={qkv_dim} != 2*H*K + HV*V = {2 * H * K + HV * V}" + ) + + # Allocate outputs in target contiguous layout + q = torch.empty(L, H, K, dtype=dtype, device=device) + k = torch.empty(L, H, K, dtype=dtype, device=device) + v = torch.empty(L, HV, V, dtype=dtype, device=device) + g = torch.empty(L, HV, dtype=torch.float32, device=device) + beta = torch.empty(L, HV, dtype=torch.float32, device=device) + + if L == 0: + return q, k, v, g, beta + + # ---- Kernel config ---- + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + BLOCK_T = 16 # tokens per block + + # Single kernel: blocks [0,H) do Q/K, blocks [H, H+HV) do V+gating + grid = (triton.cdiv(L, BLOCK_T), H + HV) + _fused_post_conv_kernel[grid]( + mixed_qkv_ptr=conv_output, + a_ptr=a, + b_ptr=b, + A_log_ptr=A_log, + dt_bias_ptr=dt_bias, + q_ptr=q, + k_ptr=k, + v_ptr=v, + g_ptr=g, + beta_ptr=beta, + stride_x_tok=conv_output.stride(0), + stride_a_tok=a.stride(0), + stride_b_tok=b.stride(0), + stride_q_tok=q.stride(0), + stride_k_tok=k.stride(0), + stride_v_tok=v.stride(0), + L=L, + H=H, + HV=HV, + K=K, + V=V, + APPLY_L2NORM=apply_l2norm, + L2NORM_EPS=1e-6, + 
OUTPUT_G_EXP=output_g_exp, + SOFTPLUS_THRESHOLD=20.0, + BLOCK_T=BLOCK_T, + BK=BK, + BV=BV, + num_warps=4, + num_stages=2, + ) + + return q, k, v, g, beta diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index 2b952e10e..c5ea14cab 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.fla.ops import ( chunk_gated_delta_rule as fla_chunk_gated_delta_rule, ) from vllm.model_executor.layers.fla.ops import ( + fused_post_conv_prep, fused_recurrent_gated_delta_rule_packed_decode, fused_sigmoid_gating_delta_rule_update, ) @@ -774,19 +775,44 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): mixed_qkv_non_spec = None query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) - query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec - ) - if attn_metadata.num_prefills > 0: - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + assert mixed_qkv_non_spec is not None, ( + "mixed_qkv_non_spec must be provided for prefill path" + ) if spec_sequence_masks is not None: - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) + a_non_spec = a.index_select(0, non_spec_token_indx) + b_non_spec = b.index_select(0, non_spec_token_indx) else: - g_non_spec = g - beta_non_spec = beta + a_non_spec = a + b_non_spec = b + + ( + query_non_spec, + key_non_spec, + value_non_spec, + g_non_spec, + beta_non_spec, + ) = fused_post_conv_prep( + conv_output=mixed_qkv_non_spec, + a=a_non_spec, + b=b_non_spec, + A_log=self.A_log, + dt_bias=self.dt_bias, + num_k_heads=self.num_k_heads // self.tp_size, + head_k_dim=self.head_k_dim, + head_v_dim=self.head_v_dim, + apply_l2norm=True, + output_g_exp=False, + ) + query_non_spec = query_non_spec.unsqueeze(0) + key_non_spec = 
key_non_spec.unsqueeze(0) + value_non_spec = value_non_spec.unsqueeze(0) + g_non_spec = g_non_spec.unsqueeze(0) + beta_non_spec = beta_non_spec.unsqueeze(0) else: + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec + ) g_non_spec = None beta_non_spec = None @@ -832,7 +858,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): initial_state=initial_state, output_final_state=True, cu_seqlens=non_spec_query_start_loc, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=False, ) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(