# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the fused sigmoid-gating delta-rule update kernel.

Each test checks that `fused_sigmoid_gating_delta_rule_update` matches a
reference path that computes the gating terms on the host
(beta = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias)) and then
calls `fused_recurrent_gated_delta_rule`.
"""

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fla.ops import (
    fused_recurrent_gated_delta_rule,
    fused_sigmoid_gating_delta_rule_update,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

DEVICE = current_platform.device_type


@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("num_reqs", [1, 2, 4])
@pytest.mark.parametrize("num_k_heads", [16])
@pytest.mark.parametrize("num_v_heads", [32])
@pytest.mark.parametrize("head_k_dim", [128])
@pytest.mark.parametrize("head_v_dim", [128])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_fused_sigmoid_gating_delta_rule_update_non_spec(
    tp_size: int,
    num_reqs: int,
    num_k_heads: int,
    num_v_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    dtype: torch.dtype,
) -> None:
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    key_dim = head_k_dim * num_k_heads
    value_dim = head_v_dim * num_v_heads
    mixed_qkv_dim = (key_dim * 2 + value_dim) // tp_size
    seq_len = 1  # seq_len is 1 for decode
    num_tokens = num_reqs * seq_len
    total_entries = num_tokens * 2

    mixed_qkv = torch.rand(num_tokens, mixed_qkv_dim, dtype=dtype)
    query, key, value = torch.split(
        mixed_qkv,
        [
            key_dim // tp_size,
            key_dim // tp_size,
            value_dim // tp_size,
        ],
        dim=-1,
    )
    query = query.view(1, num_tokens, num_k_heads, head_k_dim)
    key = key.view(1, num_tokens, num_k_heads, head_k_dim)
    value = value.view(1, num_tokens, num_v_heads, head_v_dim)

    A_log = torch.rand(num_v_heads // tp_size, dtype=dtype)
    dt_bias = torch.rand(num_v_heads // tp_size, dtype=dtype)
    a = torch.rand(num_tokens, num_v_heads, dtype=dtype)
    b = torch.rand(num_tokens, num_v_heads, dtype=dtype)

    ssm_state = torch.rand(
        total_entries, num_v_heads, head_k_dim, head_v_dim, dtype=dtype
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32)[:num_tokens]
    cu_seqlens = torch.arange(0, num_tokens + 1, dtype=torch.int32)

    # Host-side gating terms; the fused kernel computes these internally.
    beta = b.sigmoid()
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)

    core_attn_out_ref, last_recurrent_state_ref = fused_recurrent_gated_delta_rule(
        q=query,
        k=key,
        v=value,
        g=g.unsqueeze(0),
        beta=beta.unsqueeze(0),
        initial_state=ssm_state.clone(),
        inplace_final_state=True,
        ssm_state_indices=state_indices,
        cu_seqlens=cu_seqlens,
        use_qk_l2norm_in_kernel=True,
    )

    core_attn_out, last_recurrent_state = fused_sigmoid_gating_delta_rule_update(
        A_log=A_log,
        a=a,
        b=b,
        dt_bias=dt_bias,
        q=query,
        k=key,
        v=value,
        initial_state=ssm_state,
        inplace_final_state=True,
        ssm_state_indices=state_indices,
        cu_seqlens=cu_seqlens,
        use_qk_l2norm_in_kernel=True,
    )

    torch.testing.assert_close(core_attn_out, core_attn_out_ref, atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(
        last_recurrent_state, last_recurrent_state_ref, atol=1e-2, rtol=1e-2
    )
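
# ---------------------------------------------------------------------------
# For orientation only: a minimal, unoptimized sketch of the per-token
# recurrence these kernels are understood to implement, assuming the
# standard gated delta rule (decay the state by exp(g), apply a rank-1
# delta-rule correction scaled by beta, read out with the query).
# `naive_gated_delta_step` is a hypothetical helper local to this file,
# not part of the vLLM API; it elides the q/k L2 normalization enabled by
# `use_qk_l2norm_in_kernel` and the k/v head-count broadcasting
# (num_k_heads != num_v_heads) that the real kernels handle.
# ---------------------------------------------------------------------------
def naive_gated_delta_step(
    S: torch.Tensor,  # (num_heads, head_k_dim, head_v_dim) recurrent state
    q: torch.Tensor,  # (num_heads, head_k_dim) query for this token
    k: torch.Tensor,  # (num_heads, head_k_dim) key for this token
    v: torch.Tensor,  # (num_heads, head_v_dim) value for this token
    g: torch.Tensor,  # (num_heads,) log decay gate, g <= 0
    beta: torch.Tensor,  # (num_heads,) write strength in (0, 1)
) -> tuple[torch.Tensor, torch.Tensor]:
    S = S * g.exp()[:, None, None]  # gated decay of the state
    v_err = v - torch.einsum("hkv,hk->hv", S, k)  # delta-rule prediction error
    S = S + beta[:, None, None] * torch.einsum("hk,hv->hkv", k, v_err)
    o = torch.einsum("hkv,hk->hv", S, q)  # read out with the query
    return S, o
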
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("num_reqs", [1, 2, 4])
@pytest.mark.parametrize("num_k_heads", [16])
@pytest.mark.parametrize("num_v_heads", [32])
@pytest.mark.parametrize("head_k_dim", [128])
@pytest.mark.parametrize("head_v_dim", [128])
@pytest.mark.parametrize("num_speculative_tokens", [1, 3])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_fused_sigmoid_gating_delta_rule_update_spec(
    tp_size: int,
    num_reqs: int,
    num_k_heads: int,
    num_v_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    num_speculative_tokens: int,
    dtype: torch.dtype,
) -> None:
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    key_dim = head_k_dim * num_k_heads
    value_dim = head_v_dim * num_v_heads
    mixed_qkv_dim = (key_dim * 2 + value_dim) // tp_size
    # Each request carries num_speculative_tokens draft tokens plus one
    # bonus token.
    num_tokens = num_reqs * (num_speculative_tokens + 1)
    total_entries = num_tokens * 2

    mixed_qkv = torch.rand(num_tokens, mixed_qkv_dim, dtype=dtype)
    query, key, value = torch.split(
        mixed_qkv,
        [
            key_dim // tp_size,
            key_dim // tp_size,
            value_dim // tp_size,
        ],
        dim=-1,
    )
    query = query.view(1, num_tokens, num_k_heads, head_k_dim)
    key = key.view(1, num_tokens, num_k_heads, head_k_dim)
    value = value.view(1, num_tokens, num_v_heads, head_v_dim)

    A_log = torch.rand(num_v_heads // tp_size, dtype=dtype)
    dt_bias = torch.rand(num_v_heads // tp_size, dtype=dtype)
    a = torch.rand(num_tokens, num_v_heads, dtype=dtype)
    b = torch.rand(num_tokens, num_v_heads, dtype=dtype)

    ssm_state = torch.rand(
        total_entries, num_v_heads, head_k_dim, head_v_dim, dtype=dtype
    )
    # One state slot per token, laid out per request.
    state_indices = torch.randperm(
        total_entries,
        dtype=torch.int32,
    )[:num_tokens].view(num_reqs, num_speculative_tokens + 1)
    # Each request accepts between 1 and num_speculative_tokens tokens.
    num_accepted_tokens = torch.randint(
        1, num_speculative_tokens + 1, (num_reqs,), dtype=torch.int32
    )
    cu_seqlens = torch.arange(
        0, num_tokens + 1, num_speculative_tokens + 1, dtype=torch.int32
    )

    # Host-side gating terms; the fused kernel computes these internally.
    beta = b.sigmoid()
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)

    core_attn_out_ref, last_recurrent_state_ref = fused_recurrent_gated_delta_rule(
        q=query,
        k=key,
        v=value,
        g=g.unsqueeze(0),
        beta=beta.unsqueeze(0),
        initial_state=ssm_state.clone(),
        inplace_final_state=True,
        ssm_state_indices=state_indices,
        cu_seqlens=cu_seqlens,
        num_accepted_tokens=num_accepted_tokens,
        use_qk_l2norm_in_kernel=True,
    )

    core_attn_out, last_recurrent_state = fused_sigmoid_gating_delta_rule_update(
        A_log=A_log,
        a=a,
        b=b,
        dt_bias=dt_bias,
        q=query,
        k=key,
        v=value,
        initial_state=ssm_state,
        inplace_final_state=True,
        ssm_state_indices=state_indices,
        cu_seqlens=cu_seqlens,
        num_accepted_tokens=num_accepted_tokens,
        use_qk_l2norm_in_kernel=True,
    )

    torch.testing.assert_close(core_attn_out, core_attn_out_ref, atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(
        last_recurrent_state, last_recurrent_state_ref, atol=1e-2, rtol=1e-2
    )