[Perf] add packed recurrent fast path for decode (#36596)
Signed-off-by: hdj <1293066020@qq.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
98
tests/kernels/test_fused_recurrent_packed_decode.py
Normal file
98
tests/kernels/test_fused_recurrent_packed_decode.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("strided_mixed_qkv", [False, True])
|
||||
def test_fused_recurrent_packed_decode_matches_reference(
|
||||
dtype: torch.dtype, strided_mixed_qkv: bool
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Small but representative GDN config (Qwen3Next defaults are K=128, V=128).
|
||||
B = 32
|
||||
H = 4
|
||||
HV = 8 # grouped value attention: HV must be divisible by H
|
||||
K = 128
|
||||
V = 128
|
||||
qkv_dim = 2 * (H * K) + (HV * V)
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
if strided_mixed_qkv:
|
||||
# Simulate a packed view into a larger projection buffer:
|
||||
# mixed_qkv.stride(0) > mixed_qkv.shape[1]
|
||||
proj = torch.randn((B, qkv_dim + 64), device=device, dtype=dtype)
|
||||
mixed_qkv = proj[:, :qkv_dim]
|
||||
else:
|
||||
mixed_qkv = torch.randn((B, qkv_dim), device=device, dtype=dtype)
|
||||
|
||||
a = torch.randn((B, HV), device=device, dtype=dtype)
|
||||
b = torch.randn((B, HV), device=device, dtype=dtype)
|
||||
A_log = torch.randn((HV,), device=device, dtype=dtype)
|
||||
dt_bias = torch.randn((HV,), device=device, dtype=dtype)
|
||||
|
||||
# Continuous batching indices (include PAD_SLOT_ID=-1 cases).
|
||||
ssm_state_indices = torch.arange(B, device=device, dtype=torch.int32)
|
||||
ssm_state_indices[-3:] = -1
|
||||
|
||||
state0 = torch.randn((B, HV, V, K), device=device, dtype=dtype)
|
||||
state_ref = state0.clone()
|
||||
state_packed = state0.clone()
|
||||
|
||||
out_packed = torch.empty((B, 1, HV, V), device=device, dtype=dtype)
|
||||
|
||||
# Reference path: materialize contiguous Q/K/V + explicit gating.
|
||||
q, k, v = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)
|
||||
q = q.view(B, H, K).unsqueeze(1).contiguous()
|
||||
k = k.view(B, H, K).unsqueeze(1).contiguous()
|
||||
v = v.view(B, HV, V).unsqueeze(1).contiguous()
|
||||
|
||||
x = a.float() + dt_bias.float()
|
||||
softplus_x = torch.where(
|
||||
x <= 20.0, torch.log1p(torch.exp(torch.clamp(x, max=20.0))), x
|
||||
)
|
||||
g = (-torch.exp(A_log.float()) * softplus_x).unsqueeze(1)
|
||||
beta = torch.sigmoid(b.float()).to(dtype).unsqueeze(1)
|
||||
|
||||
out_ref, state_ref = fused_recurrent_gated_delta_rule(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=K**-0.5,
|
||||
initial_state=state_ref,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=None,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
# Packed path: fused gating + recurrent directly from packed mixed_qkv.
|
||||
fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv=mixed_qkv,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
scale=K**-0.5,
|
||||
initial_state=state_packed,
|
||||
out=out_packed,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
atol = 2e-2 if dtype != torch.float32 else 1e-4
|
||||
rtol = 1e-2 if dtype != torch.float32 else 1e-4
|
||||
torch.testing.assert_close(out_packed, out_ref, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(state_packed, state_ref, rtol=rtol, atol=atol)
|
||||
@@ -96,6 +96,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||
VLLM_SKIP_P2P_CHECK: bool = False
|
||||
VLLM_DISABLED_KERNELS: list[str] = []
|
||||
VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True
|
||||
VLLM_DISABLE_PYNCCL: bool = False
|
||||
VLLM_USE_OINK_OPS: bool = False
|
||||
VLLM_ROCM_USE_AITER: bool = False
|
||||
@@ -899,6 +900,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_DISABLED_KERNELS": lambda: []
|
||||
if "VLLM_DISABLED_KERNELS" not in os.environ
|
||||
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
|
||||
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
|
||||
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
|
||||
),
|
||||
# Disable pynccl (using torch.distributed instead)
|
||||
"VLLM_DISABLE_PYNCCL": lambda: (
|
||||
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
||||
from .fused_recurrent import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
)
|
||||
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
||||
from .layernorm_guard import RMSNormGated
|
||||
|
||||
@@ -15,5 +18,6 @@ __all__ = [
|
||||
"RMSNormGated",
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule_packed_decode",
|
||||
"fused_sigmoid_gating_delta_rule_update",
|
||||
]
|
||||
|
||||
@@ -252,6 +252,231 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
return o, final_state
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
|
||||
mixed_qkv,
|
||||
a,
|
||||
b,
|
||||
A_log,
|
||||
dt_bias,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
ssm_state_indices,
|
||||
scale,
|
||||
stride_mixed_qkv_tok: tl.constexpr,
|
||||
stride_a_tok: tl.constexpr,
|
||||
stride_b_tok: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
SOFTPLUS_THRESHOLD: tl.constexpr,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
|
||||
o_k = tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
|
||||
p_o = o + (i_n * HV + i_hv) * V + o_v
|
||||
|
||||
if state_idx < 0:
|
||||
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
|
||||
tl.store(p_o, zero, mask=mask_v)
|
||||
return
|
||||
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
|
||||
q_off = i_h * K + o_k
|
||||
k_off = (H * K) + i_h * K + o_k
|
||||
v_off = (2 * H * K) + i_hv * V + o_v
|
||||
b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
|
||||
a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
|
||||
b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
|
||||
A_log_val = tl.load(A_log + i_hv).to(tl.float32)
|
||||
dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
|
||||
x = a_val + dt_bias_val
|
||||
softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
|
||||
g_val = -tl.exp(A_log_val) * softplus_x
|
||||
beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
|
||||
|
||||
b_h *= exp(g_val)
|
||||
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||
b_v *= beta_val
|
||||
b_h += b_v[:, None] * b_k[None, :]
|
||||
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
p_ht = ht + state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
A_log: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
ssm_state_indices: torch.Tensor,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if mixed_qkv.ndim != 2:
|
||||
raise ValueError(
|
||||
f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
|
||||
)
|
||||
if mixed_qkv.stride(-1) != 1:
|
||||
raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
|
||||
if a.ndim != 2 or b.ndim != 2:
|
||||
raise ValueError(
|
||||
f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
|
||||
)
|
||||
if a.stride(-1) != 1 or b.stride(-1) != 1:
|
||||
raise ValueError("`a`/`b` must be contiguous in the last dim.")
|
||||
if A_log.ndim != 1 or dt_bias.ndim != 1:
|
||||
raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
|
||||
if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
|
||||
raise ValueError("`A_log`/`dt_bias` must be contiguous.")
|
||||
if ssm_state_indices.ndim != 1:
|
||||
raise ValueError(
|
||||
f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
|
||||
)
|
||||
if not out.is_contiguous():
|
||||
raise ValueError("`out` must be contiguous.")
|
||||
|
||||
dev = mixed_qkv.device
|
||||
if (
|
||||
a.device != dev
|
||||
or b.device != dev
|
||||
or A_log.device != dev
|
||||
or dt_bias.device != dev
|
||||
or initial_state.device != dev
|
||||
or out.device != dev
|
||||
or ssm_state_indices.device != dev
|
||||
):
|
||||
raise ValueError("All inputs must be on the same device.")
|
||||
|
||||
B = mixed_qkv.shape[0]
|
||||
if a.shape[0] != B or b.shape[0] != B:
|
||||
raise ValueError(
|
||||
"Mismatched batch sizes: "
|
||||
f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
|
||||
)
|
||||
if ssm_state_indices.shape[0] != B:
|
||||
raise ValueError(
|
||||
f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
|
||||
)
|
||||
|
||||
if initial_state.ndim != 4:
|
||||
raise ValueError(
|
||||
f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
|
||||
)
|
||||
if initial_state.stride(-1) != 1:
|
||||
raise ValueError("`initial_state` must be contiguous in the last dim.")
|
||||
HV, V, K = initial_state.shape[-3:]
|
||||
if a.shape[1] != HV or b.shape[1] != HV:
|
||||
raise ValueError(
|
||||
f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
|
||||
)
|
||||
if A_log.numel() != HV or dt_bias.numel() != HV:
|
||||
raise ValueError(
|
||||
f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
|
||||
)
|
||||
if out.shape != (B, 1, HV, V):
|
||||
raise ValueError(
|
||||
f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
|
||||
)
|
||||
|
||||
qkv_dim = mixed_qkv.shape[1]
|
||||
qk_dim = qkv_dim - HV * V
|
||||
if qk_dim <= 0 or qk_dim % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
|
||||
)
|
||||
q_dim = qk_dim // 2
|
||||
if q_dim % K != 0:
|
||||
raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
|
||||
H = q_dim // K
|
||||
if H <= 0 or HV % H != 0:
|
||||
raise ValueError(
|
||||
f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
|
||||
)
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
if triton.cdiv(K, BK) != 1:
|
||||
raise ValueError(
|
||||
f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
|
||||
)
|
||||
BV = min(triton.next_power_of_2(V), 32)
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
stride_mixed_qkv_tok = mixed_qkv.stride(0)
|
||||
stride_a_tok = a.stride(0)
|
||||
stride_b_tok = b.stride(0)
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = initial_state.stride(0)
|
||||
stride_indices_seq = ssm_state_indices.stride(0)
|
||||
|
||||
NV = triton.cdiv(V, BV)
|
||||
grid = (NV, B * HV)
|
||||
fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
|
||||
mixed_qkv=mixed_qkv,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
o=out,
|
||||
h0=initial_state,
|
||||
ht=initial_state,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
scale=scale,
|
||||
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
|
||||
stride_a_tok=stride_a_tok,
|
||||
stride_b_tok=stride_b_tok,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
SOFTPLUS_THRESHOLD=20.0,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
return out, initial_state
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
|
||||
@@ -10,6 +10,7 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from vllm import envs
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
||||
@@ -474,6 +476,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
|
||||
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
||||
self.enable_packed_recurrent_decode = (
|
||||
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
@@ -747,9 +752,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Core attention computation (called by custom op).
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
@@ -762,6 +764,22 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
|
||||
if (
|
||||
self.enable_packed_recurrent_decode
|
||||
and attn_metadata.spec_sequence_masks is None
|
||||
and attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes > 0
|
||||
):
|
||||
return self._forward_core_decode_non_spec(
|
||||
mixed_qkv=mixed_qkv,
|
||||
b=b,
|
||||
a=a,
|
||||
core_attn_out=core_attn_out,
|
||||
attn_metadata=attn_metadata,
|
||||
virtual_engine=forward_context.virtual_engine,
|
||||
)
|
||||
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
@@ -946,6 +964,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
else:
|
||||
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
|
||||
|
||||
def _forward_core_decode_non_spec(
|
||||
self,
|
||||
mixed_qkv: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
attn_metadata: GDNAttentionMetadata,
|
||||
virtual_engine: int,
|
||||
):
|
||||
"""
|
||||
Core attention computation with a packed non-spec decode fast path.
|
||||
"""
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
self_kv_cache = self.kv_cache[virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
mixed_qkv = mixed_qkv[:num_actual_tokens]
|
||||
b = b[:num_actual_tokens]
|
||||
a = a[:num_actual_tokens]
|
||||
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
mixed_qkv_non_spec = causal_conv1d_update(
|
||||
mixed_qkv,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
|
||||
validate_data=False,
|
||||
)
|
||||
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
|
||||
fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv=mixed_qkv_non_spec,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=self.A_log,
|
||||
dt_bias=self.dt_bias,
|
||||
scale=self.head_k_dim**-0.5,
|
||||
initial_state=ssm_state,
|
||||
out=out_buf,
|
||||
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user