[Perf] add packed recurrent fast path for decode (#36596)
Signed-off-by: hdj <1293066020@qq.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
98
tests/kernels/test_fused_recurrent_packed_decode.py
Normal file
98
tests/kernels/test_fused_recurrent_packed_decode.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fla.ops import (
    fused_recurrent_gated_delta_rule,
    fused_recurrent_gated_delta_rule_packed_decode,
)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("strided_mixed_qkv", [False, True])
def test_fused_recurrent_packed_decode_matches_reference(
    dtype: torch.dtype, strided_mixed_qkv: bool
):
    """Packed-decode fast path must agree with the unpacked reference path.

    The reference path splits ``mixed_qkv`` into contiguous Q/K/V tensors and
    applies the gating (softplus + sigmoid) explicitly before calling
    ``fused_recurrent_gated_delta_rule``; the packed path consumes the raw
    ``mixed_qkv`` buffer directly and fuses the gating into the kernel.
    """
    torch.manual_seed(0)

    # Small but representative GDN shape (Qwen3Next defaults are K=128, V=128).
    # HV must be divisible by H (grouped value attention).
    B, H, HV, K, V = 32, 4, 8, 128, 128
    qkv_dim = 2 * (H * K) + (HV * V)
    device = torch.device("cuda")

    if strided_mixed_qkv:
        # View into a wider projection buffer so that
        # mixed_qkv.stride(0) > mixed_qkv.shape[1].
        backing = torch.randn((B, qkv_dim + 64), device=device, dtype=dtype)
        mixed_qkv = backing[:, :qkv_dim]
    else:
        mixed_qkv = torch.randn((B, qkv_dim), device=device, dtype=dtype)

    a = torch.randn((B, HV), device=device, dtype=dtype)
    b = torch.randn((B, HV), device=device, dtype=dtype)
    A_log = torch.randn((HV,), device=device, dtype=dtype)
    dt_bias = torch.randn((HV,), device=device, dtype=dtype)

    # Continuous-batching slot mapping, with a few PAD_SLOT_ID (-1) entries.
    ssm_state_indices = torch.arange(B, device=device, dtype=torch.int32)
    ssm_state_indices[-3:] = -1

    state0 = torch.randn((B, HV, V, K), device=device, dtype=dtype)
    state_ref = state0.clone()
    state_packed = state0.clone()
    out_packed = torch.empty((B, 1, HV, V), device=device, dtype=dtype)

    # --- Reference: materialize contiguous Q/K/V and gate explicitly. ---
    q, k, v = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)
    q = q.view(B, H, K).unsqueeze(1).contiguous()
    k = k.view(B, H, K).unsqueeze(1).contiguous()
    v = v.view(B, HV, V).unsqueeze(1).contiguous()

    gate_in = a.float() + dt_bias.float()
    softplus = torch.where(
        gate_in <= 20.0,
        torch.log1p(torch.exp(torch.clamp(gate_in, max=20.0))),
        gate_in,
    )
    g = (-torch.exp(A_log.float()) * softplus).unsqueeze(1)
    beta = torch.sigmoid(b.float()).to(dtype).unsqueeze(1)

    out_ref, state_ref = fused_recurrent_gated_delta_rule(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=K**-0.5,
        initial_state=state_ref,
        inplace_final_state=True,
        cu_seqlens=None,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    # --- Packed path: fused gating + recurrent straight from mixed_qkv. ---
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        scale=K**-0.5,
        initial_state=state_packed,
        out=out_packed,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    if dtype is torch.float32:
        atol, rtol = 1e-4, 1e-4
    else:
        atol, rtol = 2e-2, 1e-2
    torch.testing.assert_close(out_packed, out_ref, rtol=rtol, atol=atol)
    torch.testing.assert_close(state_packed, state_ref, rtol=rtol, atol=atol)
|
||||||
@@ -96,6 +96,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||||
VLLM_SKIP_P2P_CHECK: bool = False
|
VLLM_SKIP_P2P_CHECK: bool = False
|
||||||
VLLM_DISABLED_KERNELS: list[str] = []
|
VLLM_DISABLED_KERNELS: list[str] = []
|
||||||
|
VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True
|
||||||
VLLM_DISABLE_PYNCCL: bool = False
|
VLLM_DISABLE_PYNCCL: bool = False
|
||||||
VLLM_USE_OINK_OPS: bool = False
|
VLLM_USE_OINK_OPS: bool = False
|
||||||
VLLM_ROCM_USE_AITER: bool = False
|
VLLM_ROCM_USE_AITER: bool = False
|
||||||
@@ -899,6 +900,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_DISABLED_KERNELS": lambda: []
|
"VLLM_DISABLED_KERNELS": lambda: []
|
||||||
if "VLLM_DISABLED_KERNELS" not in os.environ
|
if "VLLM_DISABLED_KERNELS" not in os.environ
|
||||||
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
|
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
|
||||||
|
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
|
||||||
|
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
|
||||||
|
),
|
||||||
# Disable pynccl (using torch.distributed instead)
|
# Disable pynccl (using torch.distributed instead)
|
||||||
"VLLM_DISABLE_PYNCCL": lambda: (
|
"VLLM_DISABLE_PYNCCL": lambda: (
|
||||||
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
|
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
|
||||||
|
|||||||
@@ -7,7 +7,10 @@
|
|||||||
# the following copyright notice:
|
# the following copyright notice:
|
||||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
from .chunk import chunk_gated_delta_rule
|
from .chunk import chunk_gated_delta_rule
|
||||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
from .fused_recurrent import (
|
||||||
|
fused_recurrent_gated_delta_rule,
|
||||||
|
fused_recurrent_gated_delta_rule_packed_decode,
|
||||||
|
)
|
||||||
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
||||||
from .layernorm_guard import RMSNormGated
|
from .layernorm_guard import RMSNormGated
|
||||||
|
|
||||||
@@ -15,5 +18,6 @@ __all__ = [
|
|||||||
"RMSNormGated",
|
"RMSNormGated",
|
||||||
"chunk_gated_delta_rule",
|
"chunk_gated_delta_rule",
|
||||||
"fused_recurrent_gated_delta_rule",
|
"fused_recurrent_gated_delta_rule",
|
||||||
|
"fused_recurrent_gated_delta_rule_packed_decode",
|
||||||
"fused_sigmoid_gating_delta_rule_update",
|
"fused_sigmoid_gating_delta_rule_update",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -252,6 +252,231 @@ def fused_recurrent_gated_delta_rule_fwd(
|
|||||||
return o, final_state
|
return o, final_state
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
    # One decode token per sequence. Each program handles one BV-wide slice of
    # one value head (grid = (ceil(V / BV), B * HV), set up by the wrapper).
    mixed_qkv,  # [B, 2*H*K + HV*V] packed per-token Q|K|V rows
    a,  # [B, HV] gating input (pre-softplus)
    b,  # [B, HV] beta input (pre-sigmoid)
    A_log,  # [HV] log decay magnitudes
    dt_bias,  # [HV] bias added to `a` before softplus
    o,  # [B, 1, HV, V] output buffer (contiguous)
    h0,  # initial recurrent state, indexed via ssm_state_indices
    ht,  # final recurrent state; the wrapper passes the same buffer as h0
    ssm_state_indices,  # [B] state slot per token; < 0 means padded slot
    scale,  # query scaling factor (typically K**-0.5)
    stride_mixed_qkv_tok: tl.constexpr,
    stride_a_tok: tl.constexpr,
    stride_b_tok: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    H: tl.constexpr,  # number of query/key heads
    HV: tl.constexpr,  # number of value heads (multiple of H)
    K: tl.constexpr,  # head dim of Q/K
    V: tl.constexpr,  # head dim of V
    BK: tl.constexpr,  # K tile size; wrapper guarantees BK >= K (single tile)
    BV: tl.constexpr,  # V tile size
    SOFTPLUS_THRESHOLD: tl.constexpr,  # above this, softplus(x) ~= x
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    """One gated delta-rule recurrence step read directly from packed QKV."""
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map value head -> its (grouped) query/key head.
    i_h = i_hv // (HV // H)

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
    p_o = o + (i_n * HV + i_hv) * V + o_v

    if state_idx < 0:
        # Padded slot (e.g. PAD_SLOT_ID = -1): emit zeros, touch no state.
        zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
        tl.store(p_o, zero, mask=mask_v)
        return

    # Load the [BV, BK] slice of the recurrent state for this value head.
    p_h0 = h0 + state_idx * stride_init_state_token
    p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Gather Q/K/V straight out of the packed row: [Q(H*K) | K(H*K) | V(HV*V)].
    p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
    q_off = i_h * K + o_k
    k_off = (H * K) + i_h * K + o_k
    v_off = (2 * H * K) + i_hv * V + o_v
    b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
    b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
    b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)

    if USE_QK_L2NORM_IN_KERNEL:
        # L2-normalize q/k in-kernel (eps matches the reference path).
        b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
        b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
    b_q = b_q * scale

    # Fused gating: g = -exp(A_log) * softplus(a + dt_bias); beta = sigmoid(b).
    a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
    b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
    A_log_val = tl.load(A_log + i_hv).to(tl.float32)
    dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
    x = a_val + dt_bias_val
    # For x > threshold, softplus(x) ~= x; the selected branch is finite even
    # if the unselected exp overflows.
    softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
    g_val = -tl.exp(A_log_val) * softplus_x
    # Round-trip through the input dtype to bit-match the reference, which
    # computes sigmoid in fp32 and casts back to the activation dtype.
    beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)

    # Delta-rule update: decay state, remove old prediction, write new rank-1
    # term, then read out o = h @ q.
    # NOTE(review): `exp` (not tl.exp) is presumably a module-level helper
    # imported at the top of this file — confirm it is in scope.
    b_h *= exp(g_val)
    b_v -= tl.sum(b_h * b_k[None, :], 1)
    b_v *= beta_val
    b_h += b_v[:, None] * b_k[None, :]
    b_o = tl.sum(b_h * b_q[None, :], 1)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # Write the updated state back (same buffer as h0 when called in-place).
    p_ht = ht + state_idx * stride_final_state_token
    p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
||||||
|
def fused_recurrent_gated_delta_rule_packed_decode(
    mixed_qkv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    out: torch.Tensor,
    ssm_state_indices: torch.Tensor,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Single-token (decode) gated delta-rule step from a packed QKV buffer.

    Fuses the gating (softplus of ``a + dt_bias``, sigmoid of ``b``) and the
    recurrent update into one kernel that reads Q/K/V directly out of the
    packed ``mixed_qkv`` rows, avoiding the split/reshape of the generic path.

    Args:
        mixed_qkv: ``[B, 2*H*K + HV*V]`` packed rows laid out as
            ``[Q (H*K) | K (H*K) | V (HV*V)]``. Must be contiguous in the last
            dim; the row stride may exceed the width (view into a larger
            projection buffer). ``H`` is inferred from the width.
        a: ``[B, HV]`` gating input (pre-softplus).
        b: ``[B, HV]`` beta input (pre-sigmoid).
        A_log: ``[HV]`` log decay magnitudes (contiguous).
        dt_bias: ``[HV]`` bias added to ``a`` (contiguous).
        scale: query scaling factor (typically ``K**-0.5``).
        initial_state: 4D state buffer ``[*, HV, V, K]`` indexed by
            ``ssm_state_indices`` along dim 0; updated in place.
        out: preallocated output of shape ``[B, 1, HV, V]`` (contiguous).
        ssm_state_indices: ``[B]`` state slot per token; negative entries
            (PAD_SLOT_ID) produce zero output and leave the state untouched.
        use_qk_l2norm_in_kernel: L2-normalize q/k inside the kernel.

    Returns:
        ``(out, initial_state)`` — both mutated in place.

    Raises:
        ValueError: on any shape/stride/device violation below.
    """
    if mixed_qkv.ndim != 2:
        raise ValueError(
            f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
        )
    if mixed_qkv.stride(-1) != 1:
        raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(
            f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
        )
    if a.stride(-1) != 1 or b.stride(-1) != 1:
        raise ValueError("`a`/`b` must be contiguous in the last dim.")
    if A_log.ndim != 1 or dt_bias.ndim != 1:
        raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
    if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
        raise ValueError("`A_log`/`dt_bias` must be contiguous.")
    if ssm_state_indices.ndim != 1:
        raise ValueError(
            f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
        )
    if not out.is_contiguous():
        raise ValueError("`out` must be contiguous.")

    # Kernel computes raw pointer offsets, so everything must share a device.
    dev = mixed_qkv.device
    if (
        a.device != dev
        or b.device != dev
        or A_log.device != dev
        or dt_bias.device != dev
        or initial_state.device != dev
        or out.device != dev
        or ssm_state_indices.device != dev
    ):
        raise ValueError("All inputs must be on the same device.")

    B = mixed_qkv.shape[0]
    if a.shape[0] != B or b.shape[0] != B:
        raise ValueError(
            "Mismatched batch sizes: "
            f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
        )
    if ssm_state_indices.shape[0] != B:
        raise ValueError(
            f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
        )

    if initial_state.ndim != 4:
        raise ValueError(
            f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
        )
    if initial_state.stride(-1) != 1:
        raise ValueError("`initial_state` must be contiguous in the last dim.")
    # State layout fixes HV, V, K; H is derived from the packed row width.
    HV, V, K = initial_state.shape[-3:]
    if a.shape[1] != HV or b.shape[1] != HV:
        raise ValueError(
            f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
        )
    if A_log.numel() != HV or dt_bias.numel() != HV:
        raise ValueError(
            f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
        )
    if out.shape != (B, 1, HV, V):
        raise ValueError(
            f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
        )

    # Infer H: the row is [Q | K | V] with Q and K of equal size H*K.
    qkv_dim = mixed_qkv.shape[1]
    qk_dim = qkv_dim - HV * V
    if qk_dim <= 0 or qk_dim % 2 != 0:
        raise ValueError(
            f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
        )
    q_dim = qk_dim // 2
    if q_dim % K != 0:
        raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
    H = q_dim // K
    if H <= 0 or HV % H != 0:
        raise ValueError(
            f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
        )

    # Single K tile: BK = next_power_of_2(K) >= K, so this guard documents the
    # NK=1 assumption baked into the kernel.
    BK = triton.next_power_of_2(K)
    if triton.cdiv(K, BK) != 1:
        raise ValueError(
            f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
        )
    BV = min(triton.next_power_of_2(V), 32)
    num_stages = 3
    num_warps = 1

    # NOTE(review): strides are passed as tl.constexpr, so each distinct
    # stride triggers a kernel specialization — confirm this is intended.
    stride_mixed_qkv_tok = mixed_qkv.stride(0)
    stride_a_tok = a.stride(0)
    stride_b_tok = b.stride(0)
    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = initial_state.stride(0)
    stride_indices_seq = ssm_state_indices.stride(0)

    # One program per (V tile, token x value head).
    NV = triton.cdiv(V, BV)
    grid = (NV, B * HV)
    fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        o=out,
        h0=initial_state,
        # h0 and ht alias: the state is updated in place.
        ht=initial_state,
        ssm_state_indices=ssm_state_indices,
        scale=scale,
        stride_mixed_qkv_tok=stride_mixed_qkv_tok,
        stride_a_tok=stride_a_tok,
        stride_b_tok=stride_b_tok,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        SOFTPLUS_THRESHOLD=20.0,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out, initial_state
|
||||||
class FusedRecurrentFunction(torch.autograd.Function):
|
class FusedRecurrentFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from einops import rearrange
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
|
|||||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fla.ops import (
|
from vllm.model_executor.layers.fla.ops import (
|
||||||
|
fused_recurrent_gated_delta_rule_packed_decode,
|
||||||
fused_sigmoid_gating_delta_rule_update,
|
fused_sigmoid_gating_delta_rule_update,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
||||||
@@ -474,6 +476,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
||||||
|
self.enable_packed_recurrent_decode = (
|
||||||
|
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
|
||||||
|
)
|
||||||
|
|
||||||
compilation_config = get_current_vllm_config().compilation_config
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
if prefix in compilation_config.static_forward_context:
|
if prefix in compilation_config.static_forward_context:
|
||||||
@@ -747,9 +752,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
core_attn_out: torch.Tensor,
|
core_attn_out: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Core attention computation (called by custom op).
|
|
||||||
"""
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||||
|
|
||||||
@@ -762,6 +764,22 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
assert isinstance(attn_metadata, dict)
|
assert isinstance(attn_metadata, dict)
|
||||||
attn_metadata = attn_metadata[self.prefix]
|
attn_metadata = attn_metadata[self.prefix]
|
||||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.enable_packed_recurrent_decode
|
||||||
|
and attn_metadata.spec_sequence_masks is None
|
||||||
|
and attn_metadata.num_prefills == 0
|
||||||
|
and attn_metadata.num_decodes > 0
|
||||||
|
):
|
||||||
|
return self._forward_core_decode_non_spec(
|
||||||
|
mixed_qkv=mixed_qkv,
|
||||||
|
b=b,
|
||||||
|
a=a,
|
||||||
|
core_attn_out=core_attn_out,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
virtual_engine=forward_context.virtual_engine,
|
||||||
|
)
|
||||||
|
|
||||||
has_initial_state = attn_metadata.has_initial_state
|
has_initial_state = attn_metadata.has_initial_state
|
||||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||||
@@ -946,6 +964,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
else:
|
else:
|
||||||
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
|
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
|
||||||
|
|
||||||
|
def _forward_core_decode_non_spec(
    self,
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    attn_metadata: GDNAttentionMetadata,
    virtual_engine: int,
):
    """
    Core attention computation with a packed non-spec decode fast path.

    Called only when the batch is pure non-speculative decode (caller checks
    num_prefills == 0, num_decodes > 0, spec_sequence_masks is None). Runs the
    causal-conv1d state update followed by the fused packed recurrent kernel,
    writing the result into `core_attn_out` in place; returns None.
    """
    non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
    # Per-virtual-engine cache: entry 0 is the conv state, entry 1 the SSM
    # state. NOTE(review): conv state appears stored transposed — confirm
    # against the cache layout expected by causal_conv1d_update.
    self_kv_cache = self.kv_cache[virtual_engine]
    conv_state = self_kv_cache[0].transpose(-1, -2)
    ssm_state = self_kv_cache[1]
    num_actual_tokens = attn_metadata.num_actual_tokens

    # Drop padding tokens; in decode-only batches one token per sequence.
    mixed_qkv = mixed_qkv[:num_actual_tokens]
    b = b[:num_actual_tokens]
    a = a[:num_actual_tokens]

    # Single-step depthwise conv update on the packed QKV rows.
    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )
    mixed_qkv_non_spec = causal_conv1d_update(
        mixed_qkv,
        conv_state,
        conv_weights,
        self.conv1d.bias,
        self.activation,
        conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
        validate_data=False,
    )
    # The packed kernel expects out of shape [B, 1, HV, V]; write directly
    # into the caller's output buffer (no extra copy).
    out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=mixed_qkv_non_spec,
        a=a,
        b=b,
        A_log=self.A_log,
        dt_bias=self.dt_bias,
        scale=self.head_k_dim**-0.5,
        initial_state=ssm_state,
        out=out_buf,
        ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
        use_qk_l2norm_in_kernel=True,
    )
    return
||||||
|
|
||||||
class Qwen3NextAttention(nn.Module):
|
class Qwen3NextAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user