# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fla.ops import (
    fused_recurrent_gated_delta_rule,
    fused_recurrent_gated_delta_rule_packed_decode,
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA device")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("strided_mixed_qkv", [False, True])
def test_fused_recurrent_packed_decode_matches_reference(
    dtype: torch.dtype, strided_mixed_qkv: bool
):
    """Packed decode kernel must agree with the unpacked reference path.

    The packed kernel consumes ``mixed_qkv`` directly (optionally as a
    strided view into a wider projection buffer) and fuses the gating
    computation; the reference path materializes contiguous Q/K/V tensors
    and applies the gating explicitly before calling the plain recurrent
    kernel. Both the decode output and the updated SSM state are compared.
    """
    torch.manual_seed(0)
    device = torch.device("cuda")

    # Small but representative GDN config (Qwen3Next defaults are K=128, V=128).
    B, H, HV, K, V = 32, 4, 8, 128, 128  # grouped value attention: HV % H == 0
    qkv_dim = 2 * (H * K) + HV * V

    if strided_mixed_qkv:
        # Simulate a packed view into a larger projection buffer so that
        # mixed_qkv.stride(0) > mixed_qkv.shape[1].
        backing = torch.randn((B, qkv_dim + 64), device=device, dtype=dtype)
        mixed_qkv = backing[:, :qkv_dim]
    else:
        mixed_qkv = torch.randn((B, qkv_dim), device=device, dtype=dtype)

    a = torch.randn((B, HV), device=device, dtype=dtype)
    b = torch.randn((B, HV), device=device, dtype=dtype)
    A_log = torch.randn((HV,), device=device, dtype=dtype)
    dt_bias = torch.randn((HV,), device=device, dtype=dtype)

    # Continuous-batching slot mapping; last three rows exercise PAD_SLOT_ID=-1.
    ssm_state_indices = torch.arange(B, device=device, dtype=torch.int32)
    ssm_state_indices[-3:] = -1

    initial = torch.randn((B, HV, V, K), device=device, dtype=dtype)
    state_ref = initial.clone()
    state_packed = initial.clone()
    out_packed = torch.empty((B, 1, HV, V), device=device, dtype=dtype)

    # --- Reference path: contiguous Q/K/V plus explicit gating. ---
    q_flat, k_flat, v_flat = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)
    q = q_flat.view(B, H, K).unsqueeze(1).contiguous()
    k = k_flat.view(B, H, K).unsqueeze(1).contiguous()
    v = v_flat.view(B, HV, V).unsqueeze(1).contiguous()

    dt = a.float() + dt_bias.float()
    # Numerically stable softplus: switch to the identity for large inputs;
    # the clamp keeps exp() from overflowing on the unselected branch.
    sp = torch.where(
        dt <= 20.0, torch.log1p(torch.exp(torch.clamp(dt, max=20.0))), dt
    )
    g = (-torch.exp(A_log.float()) * sp).unsqueeze(1)
    beta = torch.sigmoid(b.float()).to(dtype).unsqueeze(1)

    out_ref, state_ref = fused_recurrent_gated_delta_rule(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=K**-0.5,
        initial_state=state_ref,
        inplace_final_state=True,
        cu_seqlens=None,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    # --- Packed path: fused gating + recurrence directly from mixed_qkv. ---
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        scale=K**-0.5,
        initial_state=state_packed,
        out=out_packed,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    # Looser tolerances for half-precision dtypes.
    loose = dtype != torch.float32
    atol = 2e-2 if loose else 1e-4
    rtol = 1e-2 if loose else 1e-4
    torch.testing.assert_close(out_packed, out_ref, rtol=rtol, atol=atol)
    torch.testing.assert_close(state_packed, state_ref, rtol=rtol, atol=atol)
|