diff --git a/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py b/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
index 021171d88..a8781afd8 100644
--- a/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
+++ b/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import set_random_seed
 
 
@@ -68,9 +69,17 @@ def test_concat_and_cache_mla_rope_fused(
     k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
     kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)
 
-    # NOTE(woosuk): The reference implementation should be executed first
-    # because the custom kernel is in-place.
-    ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
+    if current_platform.is_rocm():
+        # We use forward_hip for the same numerics as the fused custom kernel on ROCm
+        # when dtype is FP16. The torch-native implementation implicitly upcasts
+        # FP16 x FP16 multiplications to FP32 before downcasting them, which leads
+        # to notable output divergences.
+        # Clone the tensors because the implementation modifies them in-place
+        ref_q_pe, ref_k_pe = rope.forward_hip(positions, query.clone(), k_pe.clone())
+    else:
+        # NOTE(woosuk): The reference implementation should be executed first
+        # because the custom kernel is in-place.
+        ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
 
     assert ref_k_pe is not None
     ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)
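
Note: the comment in the hunk above refers to rounding divergence between a reference path that upcasts FP16 products to FP32 and a kernel that keeps intermediates in FP16. The snippet below is a minimal, self-contained sketch of that effect using only stock PyTorch; the tensors and the rotation-style formula are illustrative assumptions and are not taken from the test or the vLLM kernels.

import torch

torch.manual_seed(0)

x1 = torch.randn(4096, dtype=torch.float16)
x2 = torch.randn(4096, dtype=torch.float16)
cos = torch.rand(4096, dtype=torch.float16)
sin = torch.rand(4096, dtype=torch.float16)

# Reference-style math: upcast to FP32, do the rotation, downcast once at the end.
ref = (x1.float() * cos.float() - x2.float() * sin.float()).to(torch.float16)

# Kernel-style math: every product and the subtraction are rounded in FP16.
fused_like = x1 * cos - x2 * sin

# The two results differ by a few ULPs on some elements, which is enough to
# trip a strict allclose check against a kernel that stays in half precision.
print((ref - fused_like).abs().max().item())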