diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 86efefc37..796912a68 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -266,22 +266,6 @@ def create_and_prepopulate_kv_cache(
     return kv_cache
 
 
-class MockAttentionLayer:
-    """A mock attention layer for testing."""
-
-    def __init__(self, device: torch.device):
-        self._q_scale = torch.tensor(1.0, device=device)
-        self._k_scale = torch.tensor(1.0, device=device)
-        self._v_scale = torch.tensor(1.0, device=device)
-        self._prob_scale = torch.tensor(1.0, device=device)
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-    def forward(self, *_args, **_kwargs):
-        raise NotImplementedError
-
-
 class MockSparseMLAAttentionLayer:
     """A mock sparse MLA attention layer for testing.
 
@@ -304,6 +288,8 @@ class MockSparseMLAAttentionLayer:
         device: torch.device,
         W_UK: torch.Tensor,
         W_UV: torch.Tensor,
+        q_scale: float,
+        k_scale: float,
     ):
         self.impl = impl
         self.num_heads = num_heads
@@ -319,13 +305,13 @@ class MockSparseMLAAttentionLayer:
         self.W_UV = W_UV.transpose(0, 1)
 
         # Scale attributes needed by attention backends
-        self._q_scale = torch.tensor(1.0, device=device)
-        self._k_scale = torch.tensor(1.0, device=device)
-        self._v_scale = torch.tensor(1.0, device=device)
+        self._q_scale = torch.tensor(q_scale, device=device)
+        self._k_scale = torch.tensor(k_scale, device=device)
+        self._v_scale = torch.tensor(float("nan"), device=device)
         self._prob_scale = torch.tensor(1.0, device=device)
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
+        self._q_scale_float = q_scale
+        self._k_scale_float = k_scale
+        self._v_scale_float = float("nan")
 
         self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
             static=True,
@@ -420,6 +406,8 @@ class MockMLAAttentionLayer(AttentionLayerBase):
         kv_lora_rank: int,
         device: torch.device,
         kv_b_proj,
+        q_scale: float,
+        k_scale: float,
     ):
         self.impl = impl
         self.num_heads = num_heads
@@ -443,13 +431,13 @@ class MockMLAAttentionLayer(AttentionLayerBase):
         self.W_UK_T = W_UK.permute(1, 2, 0)
 
         # Scale attributes needed by attention backends
-        self._q_scale = torch.tensor(1.0, device=device)
-        self._k_scale = torch.tensor(1.0, device=device)
-        self._v_scale = torch.tensor(1.0, device=device)
+        self._q_scale = torch.tensor(q_scale, device=device)
+        self._k_scale = torch.tensor(k_scale, device=device)
+        self._v_scale = torch.tensor(float("nan"), device=device)
         self._prob_scale = torch.tensor(1.0, device=device)
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
+        self._q_scale_float = q_scale
+        self._k_scale_float = k_scale
+        self._v_scale_float = float("nan")
 
         self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
             static=True,
@@ -568,6 +556,8 @@ def run_attention_backend(
     qk_rope_head_dim: int,
     v_head_dim: int,
     mock_kv_b_proj,
+    q_scale: float,
+    k_scale: float,
     kv_cache_dtype: str = "auto",
 ) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
@@ -625,6 +615,8 @@
         kv_lora_rank=kv_lora_rank,
         device=device,
         kv_b_proj=mock_kv_b_proj,
+        q_scale=q_scale,
+        k_scale=k_scale,
     )
 
     # Populate static_forward_context with mock attention layers
@@ -674,6 +666,7 @@
 @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
 @pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
+@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
 def test_backend_correctness(
     default_vllm_config,
     dist_init,
@@ -681,6 +674,8 @@ def test_backend_correctness(
     model: str,
     tensor_parallel_size: int,
     kv_cache_dtype: str,
+    q_scale: float,
+    k_scale: float,
 ):
     """
     Test that all backends produce similar outputs to a reference implementation
@@ -709,6 +704,11 @@ def test_backend_correctness(
         for b in BACKENDS_TO_TEST
         if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
     ]
+    if (
+        q_scale != 1.0 or k_scale != 1.0
+    ) and AttentionBackendEnum.CUTLASS_MLA in backends_to_test:
+        # CUTLASS_MLA does not support non-1 Q/K scales
+        backends_to_test.remove(AttentionBackendEnum.CUTLASS_MLA)
     if not backends_to_test:
         pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
 
@@ -1029,6 +1029,7 @@ def test_backend_correctness(
             common_attn_metadata=common_attn_metadata,
             randomize_blocks=True,
             kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
         )
         kv_cache_per_block_size[block_size] = kv_cache
 
@@ -1072,6 +1073,8 @@ def test_backend_correctness(
             qk_rope_head_dim,
             v_head_dim,
             mock_kv_b_proj,
+            q_scale=q_scale,
+            k_scale=k_scale,
             kv_cache_dtype=kv_cache_dtype,
         )
 
diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py
index 0fd0ba6fa..3f6faf51d 100644
--- a/tests/v1/attention/test_sparse_mla_backends.py
+++ b/tests/v1/attention/test_sparse_mla_backends.py
@@ -178,6 +178,7 @@ def _quantize_dequantize_fp8_ds_mla(
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"])
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
 @pytest.mark.parametrize("block_size", [32, 64])
+@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
 def test_sparse_backend_decode_correctness(
     default_vllm_config,
     dist_init,
@@ -187,6 +188,8 @@
     tensor_parallel_size,
     block_size,
     workspace_init,
+    q_scale: float,
+    k_scale: float,
 ):
     if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
         pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
@@ -332,7 +335,7 @@ def test_sparse_backend_decode_correctness(
     kv_c_contexts, k_pe_contexts = [], []
     reference_outputs = []
 
-    kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
+    kv_cache_scale = torch.tensor(k_scale, dtype=torch.float32, device=device)
 
     global_token_idx = 0
     for i in range(batch_spec.batch_size):
@@ -490,6 +493,8 @@ def test_sparse_backend_decode_correctness(
         device=device,
         W_UK=W_UK,
         W_UV=W_UV,
+        q_scale=q_scale,
+        k_scale=k_scale,
     )
 
     out_buffer = torch.empty(
@@ -513,7 +518,9 @@ def test_sparse_backend_decode_correctness(
     # FP8 quantization introduces some error, but should be within reasonable bounds
     # BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
     if kv_cache_dtype.startswith("fp8"):
-        torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
+        torch.testing.assert_close(
+            backend_output, sdpa_reference, rtol=0.065, atol=0.05
+        )
     else:
         torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
 
diff --git a/tests/v1/attention/test_trtllm_attention_integration.py b/tests/v1/attention/test_trtllm_attention_integration.py
index 50a2c8625..113442bf6 100644
--- a/tests/v1/attention/test_trtllm_attention_integration.py
+++ b/tests/v1/attention/test_trtllm_attention_integration.py
@@ -43,12 +43,12 @@ class MockAttentionLayer:
     """Minimal mock of an attention layer for testing."""
 
     def __init__(self, device: torch.device):
-        self._q_scale = torch.tensor(1.0, device=device)
-        self._k_scale = torch.tensor(1.0, device=device)
-        self._v_scale = torch.tensor(1.0, device=device)
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
+        self._q_scale = torch.tensor(2.0, device=device)
+        self._k_scale = torch.tensor(3.0, device=device)
+        self._v_scale = torch.tensor(4.0, device=device)
+        self._q_scale_float = 2.0
+        self._k_scale_float = 3.0
+        self._v_scale_float = 4.0
         self._o_scale_float = None
 
 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 411ec746c..da97f612a 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -1319,10 +1319,14 @@ class FlashInferImpl(AttentionImpl):
         )
 
         if self.bmm1_scale is None:
-            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
+            self.bmm1_scale = self.scale
+            if self.kv_cache_dtype.startswith("fp8"):
+                self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
 
         if self.bmm2_scale is None:
-            self.bmm2_scale = layer._v_scale_float
+            self.bmm2_scale = 1.0
+            if self.kv_cache_dtype.startswith("fp8"):
+                self.bmm2_scale *= layer._v_scale_float
 
         prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
         decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index 19faf3c93..b01ce2be2 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -255,6 +255,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
+        if layer._q_scale_float != 1.0 or layer._k_scale_float != 1.0:
+            raise NotImplementedError(
+                "CutlassMLAImpl does not support scaling for q and kv_latent yet"
+            )
+
         if type(q) is tuple:
             q_nope, q_pe = q
         else:
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index ec8f4e640..c2ce8ac5b 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -177,9 +177,14 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
 
         if self.bmm1_scale is None:
-            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
+            self.bmm1_scale = self.scale
+            if self.kv_cache_dtype.startswith("fp8"):
+                self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
+
         if self.bmm2_scale is None:
-            self.bmm2_scale = layer._v_scale_float
+            self.bmm2_scale = 1.0
+            if self.kv_cache_dtype.startswith("fp8"):
+                self.bmm2_scale *= layer._k_scale_float
 
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
index 7f334bf01..9554457b4 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
@@ -340,9 +340,13 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
             self._workspace_buffer = _get_workspace_buffer(q.device)
 
         if self.bmm1_scale is None:
-            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
+            self.bmm1_scale = self.scale
+            if self.kv_cache_dtype.startswith("fp8"):
+                self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
         if self.bmm2_scale is None:
-            self.bmm2_scale = layer._v_scale_float
+            self.bmm2_scale = 1.0
+            if self.kv_cache_dtype.startswith("fp8"):
+                self.bmm2_scale *= layer._k_scale_float
 
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q.unsqueeze(1),
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index d1b007a80..b205066d6 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -187,7 +187,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             self.scale,
             PAGE_SIZE,
             k_scale=layer._k_scale,
-            v_scale=layer._v_scale,
+            v_scale=layer._k_scale,
         )
 
         return o, lse