[Perf] fuse kernels in gdn (#37813)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
209 lines — tests/kernels/test_fused_gdn_post_conv.py (new file)
@@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for fused_gdn_prefill_post_conv kernel.
|
||||
|
||||
Verifies that the fused kernel matches the reference:
|
||||
split → rearrange → contiguous → l2norm → gating
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.fla.ops.fused_gdn_prefill_post_conv import (
|
||||
fused_post_conv_prep,
|
||||
)
|
||||
|
||||
|
||||
def reference_post_conv(
    conv_output: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    H: int,
    K: int,
    V: int,
    apply_l2norm: bool = True,
    output_g_exp: bool = False,
):
    """Reference implementation built from individual torch ops.

    Slices the packed [q | k | v] conv output into per-head tensors,
    optionally L2-normalizes q/k (in fp32), and computes the gating
    terms g = -exp(A_log) * softplus(a + dt_bias) and beta = sigmoid(b).
    """
    num_tokens = conv_output.shape[0]
    num_v_heads = A_log.shape[0]
    in_dtype = conv_output.dtype

    # Split the packed feature dimension and reshape into head layout.
    qk_width = H * K
    q = conv_output[..., :qk_width].reshape(num_tokens, H, K).contiguous()
    k = conv_output[..., qk_width : 2 * qk_width].reshape(num_tokens, H, K).contiguous()
    v = (
        conv_output[..., 2 * qk_width :]
        .reshape(num_tokens, num_v_heads, V)
        .contiguous()
    )

    # Optional per-head L2 normalization, done in fp32 then cast back.
    if apply_l2norm:
        q = F.normalize(q.float(), p=2, dim=-1, eps=1e-6).to(in_dtype)
        k = F.normalize(k.float(), p=2, dim=-1, eps=1e-6).to(in_dtype)

    # Gating terms (always fp32).
    sp = F.softplus(a.float() + dt_bias.float(), beta=1.0, threshold=20.0)
    g = -torch.exp(A_log.float()) * sp
    if output_g_exp:
        g = torch.exp(g)

    beta_out = torch.sigmoid(b.float())

    return q, k, v, g, beta_out
|
||||
|
||||
|
||||
# Qwen3.5-35B config: H=16, HV=32, K=128, V=128
# Qwen3.5-397B config: H=16, HV=64, K=128, V=128
@pytest.mark.parametrize(
    "H, HV, K, V",
    [
        (16, 32, 128, 128),  # 35B
        (16, 64, 128, 128),  # 397B
        (4, 8, 64, 64),  # small
    ],
)
@pytest.mark.parametrize("L", [1, 16, 128, 512, 2048])
@pytest.mark.parametrize("apply_l2norm", [True, False])
@pytest.mark.parametrize("output_g_exp", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_post_conv_correctness(H, HV, K, V, L, apply_l2norm, output_g_exp, dtype):
    """Fused kernel must agree with the unfused reference for every config."""
    torch.manual_seed(42)
    dev = "cuda"

    conv_output = torch.randn(L, 2 * H * K + HV * V, dtype=dtype, device=dev)
    a = torch.randn(L, HV, dtype=dtype, device=dev)
    b = torch.randn(L, HV, dtype=dtype, device=dev)
    A_log = torch.randn(HV, dtype=torch.float32, device=dev) - 2.0
    dt_bias = torch.randn(HV, dtype=torch.float32, device=dev) * 0.1

    # Reference path (individual torch ops).
    ref = reference_post_conv(
        conv_output, a, b, A_log, dt_bias, H, K, V, apply_l2norm, output_g_exp
    )

    # Fused Triton path.
    out = fused_post_conv_prep(
        conv_output,
        a,
        b,
        A_log,
        dt_bias,
        num_k_heads=H,
        head_k_dim=K,
        head_v_dim=V,
        apply_l2norm=apply_l2norm,
        output_g_exp=output_g_exp,
    )
    q, k, v, g, beta = out

    # Shapes and dtypes of the fused outputs.
    expected = {
        "q": ((L, H, K), dtype),
        "k": ((L, H, K), dtype),
        "v": ((L, HV, V), dtype),
        "g": ((L, HV), torch.float32),
        "beta": ((L, HV), torch.float32),
    }
    for name, tensor in zip(expected, out):
        want_shape, want_dtype = expected[name]
        assert tensor.shape == want_shape, f"{name} shape: {tensor.shape}"
        assert tensor.dtype == want_dtype

    # q/k/v must come out in the target contiguous layout.
    for tensor in (q, k, v):
        assert tensor.is_contiguous()

    # Numerics: bf16 l2norm needs a looser tolerance on q/k.
    tol_qk = 1e-2 if apply_l2norm else 1e-3
    torch.testing.assert_close(q, ref[0], atol=tol_qk, rtol=tol_qk)
    torch.testing.assert_close(k, ref[1], atol=tol_qk, rtol=tol_qk)
    torch.testing.assert_close(v, ref[2], atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(g, ref[3], atol=1e-4, rtol=1e-4)
    torch.testing.assert_close(beta, ref[4], atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("L", [1, 64, 256])
def test_fused_post_conv_sanity(L):
    """Sanity checks: no NaN, unit-norm q/k, beta in (0,1)."""
    torch.manual_seed(0)
    dev = "cuda"
    H, HV, K, V = 16, 32, 128, 128

    conv_output = torch.randn(
        L, 2 * H * K + HV * V, dtype=torch.bfloat16, device=dev
    )
    a = torch.randn(L, HV, dtype=torch.bfloat16, device=dev)
    b = torch.randn(L, HV, dtype=torch.bfloat16, device=dev)
    A_log = torch.randn(HV, dtype=torch.float32, device=dev) - 2.0
    dt_bias = torch.randn(HV, dtype=torch.float32, device=dev)

    q, k, v, g, beta = fused_post_conv_prep(
        conv_output,
        a,
        b,
        A_log,
        dt_bias,
        num_k_heads=H,
        head_k_dim=K,
        head_v_dim=V,
    )

    # No NaNs anywhere in the outputs.
    for name, tensor in (("q", q), ("k", k), ("v", v), ("g", g), ("beta", beta)):
        assert not torch.isnan(tensor).any(), f"NaN in {name}"

    # With the default apply_l2norm=True every q/k head vector is unit-norm.
    ones = torch.ones(L, H, dtype=torch.float32, device=dev)
    torch.testing.assert_close(torch.norm(q.float(), dim=-1), ones, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(torch.norm(k.float(), dim=-1), ones, atol=1e-3, rtol=1e-3)

    # sigmoid output stays inside [0, 1].
    assert ((beta >= 0) & (beta <= 1)).all(), "beta out of range"
|
||||
|
||||
|
||||
def test_fused_post_conv_l0():
    """L=0 (zero tokens) must not crash and must return empty outputs."""
    dev = "cuda"
    H, HV, K, V = 16, 32, 128, 128

    conv_output = torch.empty(
        0, 2 * H * K + HV * V, dtype=torch.bfloat16, device=dev
    )
    a = torch.empty(0, HV, dtype=torch.bfloat16, device=dev)
    b = torch.empty(0, HV, dtype=torch.bfloat16, device=dev)
    A_log = torch.randn(HV, dtype=torch.float32, device=dev)
    dt_bias = torch.randn(HV, dtype=torch.float32, device=dev)

    q, _k, _v, g, _beta = fused_post_conv_prep(
        conv_output,
        a,
        b,
        A_log,
        dt_bias,
        num_k_heads=H,
        head_k_dim=K,
        head_v_dim=V,
    )

    # Empty-but-correctly-shaped outputs.
    assert q.shape == (0, H, K)
    assert g.shape == (0, HV)
|
||||
@@ -7,6 +7,7 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_gdn_prefill_post_conv import fused_post_conv_prep
|
||||
from .fused_recurrent import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
@@ -19,5 +20,6 @@ __all__ = [
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule_packed_decode",
|
||||
"fused_post_conv_prep",
|
||||
"fused_sigmoid_gating_delta_rule_update",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Fused post-conv1d preparation for GDN prefill.
|
||||
|
||||
Replaces the chain:
|
||||
split → rearrange → contiguous * 3 → l2norm * 2 → gating
|
||||
with a **single Triton kernel** that reads the conv'd mixed_qkv output
|
||||
and writes directly to q/k/v/g/beta in the target contiguous layout.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
def _fused_post_conv_kernel(
    # ---- inputs ----
    mixed_qkv_ptr,  # [L, qkv_dim] conv'd output (row layout [q | k | v])
    a_ptr,  # [L, HV] gating input
    b_ptr,  # [L, HV] gating input
    # ---- params ----
    A_log_ptr,  # [HV] log decay parameter
    dt_bias_ptr,  # [HV] dt bias parameter
    # ---- outputs ----
    q_ptr,  # [L, H, K] contiguous
    k_ptr,  # [L, H, K] contiguous
    v_ptr,  # [L, HV, V] contiguous
    g_ptr,  # [L, HV] float32 (stores below assume this is contiguous)
    beta_ptr,  # [L, HV] float32 (stores below assume this is contiguous)
    # ---- strides ----
    stride_x_tok,  # qkv_dim
    stride_a_tok,  # HV
    stride_b_tok,  # HV
    stride_q_tok,  # H * K
    stride_k_tok,  # H * K
    stride_v_tok,  # HV * V
    # ---- dims ----
    L,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    APPLY_L2NORM: tl.constexpr,
    L2NORM_EPS: tl.constexpr,
    OUTPUT_G_EXP: tl.constexpr,
    SOFTPLUS_THRESHOLD: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BK: tl.constexpr,  # power-of-2 cover of K (masked)
    BV: tl.constexpr,  # power-of-2 cover of V (masked)
):
    """Single fused kernel for post-conv1d preparation.

    Grid: (ceil(L, BLOCK_T), H + HV)
    - program_id(1) in [0, H): Q/K head processing + l2norm
    - program_id(1) in [H, H+HV): V head processing + gating

    Each program handles a [BLOCK_T, head_dim] tile for one head; the
    tail token block and K/V < BK/BV are handled by masking.
    """
    i_tb = tl.program_id(0)  # token-block index
    i_head = tl.program_id(1)  # flat head index over Q/K heads then V heads

    HK: tl.constexpr = H * K  # offset of the K section within a qkv row

    offs_t = i_tb * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    mask_t = offs_t < L

    if i_head < H:
        # ============ Q/K head processing ============
        i_h = i_head
        offs_k = tl.arange(0, BK)  # [BK]
        mask_k = offs_k < K
        mask_2d = mask_t[:, None] & mask_k[None, :]  # [BLOCK_T, BK]

        # Load Q features: mixed_qkv[t, i_h*K + k]
        q_offsets = offs_t[:, None] * stride_x_tok + i_h * K + offs_k[None, :]
        q_f32 = tl.load(mixed_qkv_ptr + q_offsets, mask=mask_2d, other=0).to(tl.float32)

        # Load K features: mixed_qkv[t, HK + i_h*K + k]
        k_offsets = offs_t[:, None] * stride_x_tok + HK + i_h * K + offs_k[None, :]
        k_f32 = tl.load(mixed_qkv_ptr + k_offsets, mask=mask_2d, other=0).to(tl.float32)

        if APPLY_L2NORM:
            # L2-normalize each head vector in fp32; masked lanes loaded
            # as 0 so they do not perturb the sums.
            q_sq_sum = tl.sum(q_f32 * q_f32, axis=1)  # [BLOCK_T]
            q_inv = 1.0 / tl.sqrt(q_sq_sum + L2NORM_EPS)
            q_f32 = q_f32 * q_inv[:, None]

            k_sq_sum = tl.sum(k_f32 * k_f32, axis=1)
            k_inv = 1.0 / tl.sqrt(k_sq_sum + L2NORM_EPS)
            k_f32 = k_f32 * k_inv[:, None]

        # Store Q (cast back to the output tensor's dtype)
        q_out = offs_t[:, None] * stride_q_tok + i_h * K + offs_k[None, :]
        tl.store(
            q_ptr + q_out,
            q_f32.to(q_ptr.dtype.element_ty),
            mask=mask_2d,
        )

        # Store K
        k_out = offs_t[:, None] * stride_k_tok + i_h * K + offs_k[None, :]
        tl.store(
            k_ptr + k_out,
            k_f32.to(k_ptr.dtype.element_ty),
            mask=mask_2d,
        )
    else:
        # ============ V head + gating processing ============
        i_hv = i_head - H
        offs_v = tl.arange(0, BV)  # [BV]
        mask_v = offs_v < V
        mask_2d = mask_t[:, None] & mask_v[None, :]  # [BLOCK_T, BV]

        V_OFFSET: tl.constexpr = 2 * H * K  # start of the V section in a qkv row

        # Load V features: mixed_qkv[t, 2*H*K + i_hv*V + v]
        # (pure copy into the target layout — no arithmetic on v)
        v_offsets = (
            offs_t[:, None] * stride_x_tok + V_OFFSET + i_hv * V + offs_v[None, :]
        )
        v_vals = tl.load(mixed_qkv_ptr + v_offsets, mask=mask_2d, other=0)

        # Store V
        v_out = offs_t[:, None] * stride_v_tok + i_hv * V + offs_v[None, :]
        tl.store(v_ptr + v_out, v_vals, mask=mask_2d)

        # Gating: one scalar per (token, v-head)
        A_log_val = tl.load(A_log_ptr + i_hv).to(tl.float32)
        dt_bias_val = tl.load(dt_bias_ptr + i_hv).to(tl.float32)

        a_offsets = offs_t * stride_a_tok + i_hv
        b_offsets = offs_t * stride_b_tok + i_hv
        a_vals = tl.load(a_ptr + a_offsets, mask=mask_t, other=0).to(tl.float32)
        b_vals = tl.load(b_ptr + b_offsets, mask=mask_t, other=0).to(tl.float32)

        # g = -exp(A_log) * softplus(a + dt_bias)
        # Numerically stable softplus, then the linear approximation
        # sp(x) ~= x above SOFTPLUS_THRESHOLD (matches F.softplus).
        x = a_vals + dt_bias_val
        sp = tl.where(x > 0, x + tl.log(1.0 + tl.exp(-x)), tl.log(1.0 + tl.exp(x)))
        sp = tl.where(x <= SOFTPLUS_THRESHOLD, sp, x)
        g_vals = -tl.exp(A_log_val) * sp

        if OUTPUT_G_EXP:
            # Some consumers want exp(g) directly.
            g_vals = tl.exp(g_vals)

        beta_vals = tl.sigmoid(b_vals)

        # g/beta are stored assuming contiguous [L, HV] outputs
        # (no stride parameter is passed for them).
        gb_offsets = offs_t * HV + i_hv
        tl.store(g_ptr + gb_offsets, g_vals, mask=mask_t)
        tl.store(beta_ptr + gb_offsets, beta_vals, mask=mask_t)
|
||||
|
||||
|
||||
def fused_post_conv_prep(
    conv_output: torch.Tensor,  # [L, qkv_dim] conv'd mixed_qkv
    a: torch.Tensor,  # [L, HV]
    b: torch.Tensor,  # [L, HV]
    A_log: torch.Tensor,  # [HV]
    dt_bias: torch.Tensor,  # [HV]
    num_k_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    apply_l2norm: bool = True,
    output_g_exp: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused post-conv1d prep: split + l2norm + gating in one kernel.

    Replaces split → rearrange → contiguous → l2norm → gating with a
    single Triton launch that writes q/k/v/g/beta directly in the
    target contiguous layouts.

    Args:
        conv_output: [L, qkv_dim] contiguous conv'd mixed_qkv
        a: [L, HV] gating input
        b: [L, HV] gating input
        A_log: [HV] log decay parameter
        dt_bias: [HV] dt bias parameter
        num_k_heads: number of K heads (H)
        head_k_dim: dimension per K head (K)
        head_v_dim: dimension per V head (V)
        apply_l2norm: whether to L2-normalize q and k
        output_g_exp: if True, output exp(g) instead of g (for FlashInfer)

    Returns:
        q: [L, H, K] contiguous, optionally l2-normalized
        k: [L, H, K] contiguous, optionally l2-normalized
        v: [L, HV, V] contiguous
        g: [L, HV] float32
        beta: [L, HV] float32
    """
    H, K, V = num_k_heads, head_k_dim, head_v_dim
    HV = A_log.shape[0]
    L, qkv_dim = conv_output.shape

    assert qkv_dim == 2 * H * K + HV * V, (
        f"qkv_dim={qkv_dim} != 2*H*K + HV*V = {2 * H * K + HV * V}"
    )

    # Outputs allocated directly in the target contiguous layouts;
    # new_empty keeps conv_output's dtype/device for q/k/v.
    q = conv_output.new_empty(L, H, K)
    k = conv_output.new_empty(L, H, K)
    v = conv_output.new_empty(L, HV, V)
    g = torch.empty(L, HV, dtype=torch.float32, device=conv_output.device)
    beta = torch.empty(L, HV, dtype=torch.float32, device=conv_output.device)

    # Nothing to launch for an empty batch.
    if L == 0:
        return q, k, v, g, beta

    # ---- Kernel config ----
    blk_k = triton.next_power_of_2(K)
    blk_v = triton.next_power_of_2(V)
    tokens_per_block = 16

    # One launch: programs [0, H) handle Q/K, programs [H, H+HV) handle
    # V + gating.
    grid = (triton.cdiv(L, tokens_per_block), H + HV)
    _fused_post_conv_kernel[grid](
        mixed_qkv_ptr=conv_output,
        a_ptr=a,
        b_ptr=b,
        A_log_ptr=A_log,
        dt_bias_ptr=dt_bias,
        q_ptr=q,
        k_ptr=k,
        v_ptr=v,
        g_ptr=g,
        beta_ptr=beta,
        stride_x_tok=conv_output.stride(0),
        stride_a_tok=a.stride(0),
        stride_b_tok=b.stride(0),
        stride_q_tok=q.stride(0),
        stride_k_tok=k.stride(0),
        stride_v_tok=v.stride(0),
        L=L,
        H=H,
        HV=HV,
        K=K,
        V=V,
        APPLY_L2NORM=apply_l2norm,
        L2NORM_EPS=1e-6,
        OUTPUT_G_EXP=output_g_exp,
        SOFTPLUS_THRESHOLD=20.0,
        BLOCK_T=tokens_per_block,
        BK=blk_k,
        BV=blk_v,
        num_warps=4,
        num_stages=2,
    )

    return q, k, v, g, beta
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_post_conv_prep,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
@@ -774,19 +775,44 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
mixed_qkv_non_spec = None
|
||||
|
||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_non_spec
|
||||
)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
assert mixed_qkv_non_spec is not None, (
|
||||
"mixed_qkv_non_spec must be provided for prefill path"
|
||||
)
|
||||
if spec_sequence_masks is not None:
|
||||
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||
a_non_spec = a.index_select(0, non_spec_token_indx)
|
||||
b_non_spec = b.index_select(0, non_spec_token_indx)
|
||||
else:
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
a_non_spec = a
|
||||
b_non_spec = b
|
||||
|
||||
(
|
||||
query_non_spec,
|
||||
key_non_spec,
|
||||
value_non_spec,
|
||||
g_non_spec,
|
||||
beta_non_spec,
|
||||
) = fused_post_conv_prep(
|
||||
conv_output=mixed_qkv_non_spec,
|
||||
a=a_non_spec,
|
||||
b=b_non_spec,
|
||||
A_log=self.A_log,
|
||||
dt_bias=self.dt_bias,
|
||||
num_k_heads=self.num_k_heads // self.tp_size,
|
||||
head_k_dim=self.head_k_dim,
|
||||
head_v_dim=self.head_v_dim,
|
||||
apply_l2norm=True,
|
||||
output_g_exp=False,
|
||||
)
|
||||
query_non_spec = query_non_spec.unsqueeze(0)
|
||||
key_non_spec = key_non_spec.unsqueeze(0)
|
||||
value_non_spec = value_non_spec.unsqueeze(0)
|
||||
g_non_spec = g_non_spec.unsqueeze(0)
|
||||
beta_non_spec = beta_non_spec.unsqueeze(0)
|
||||
else:
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_non_spec
|
||||
)
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
|
||||
@@ -832,7 +858,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
initial_state=initial_state,
|
||||
output_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
)
|
||||
# Init cache
|
||||
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
|
||||
|
||||
Reference in New Issue
Block a user