[Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" leading to gibberish (#37054)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
@@ -178,6 +178,7 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.parametrize("block_size", [32, 64])
|
||||
+@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
|
||||
def test_sparse_backend_decode_correctness(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
@@ -187,6 +188,8 @@ def test_sparse_backend_decode_correctness(
|
||||
tensor_parallel_size,
|
||||
block_size,
|
||||
workspace_init,
|
||||
+q_scale: float,
|
||||
+k_scale: float,
|
||||
):
|
||||
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
|
||||
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
|
||||
@@ -332,7 +335,7 @@ def test_sparse_backend_decode_correctness(
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
reference_outputs = []
|
||||
|
||||
-kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
+kv_cache_scale = torch.tensor(k_scale, dtype=torch.float32, device=device)
|
||||
global_token_idx = 0
|
||||
|
||||
for i in range(batch_spec.batch_size):
|
||||
@@ -490,6 +493,8 @@ def test_sparse_backend_decode_correctness(
|
||||
device=device,
|
||||
W_UK=W_UK,
|
||||
W_UV=W_UV,
|
||||
+q_scale=q_scale,
|
||||
+k_scale=k_scale,
|
||||
)
|
||||
|
||||
out_buffer = torch.empty(
|
||||
@@ -513,7 +518,9 @@ def test_sparse_backend_decode_correctness(
|
||||
# FP8 quantization introduces some error, but should be within reasonable bounds
|
||||
# BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
-torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
|
||||
+torch.testing.assert_close(
|
||||
+backend_output, sdpa_reference, rtol=0.065, atol=0.05
|
||||
+)
|
||||
else:
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user