[CI][Hardware][AMD] Fix test_rotary_embedding_mla_cache_fused (#32408)

Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Matt
2026-01-19 02:25:47 -06:00
committed by GitHub
parent 3c8740aacb
commit 11bbf86f6a

View File

@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -68,9 +69,17 @@ def test_concat_and_cache_mla_rope_fused(
k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
if current_platform.is_rocm():
# We use forward_hip for the same numerics as the fused custom kernel on ROCm
# when dtype is FP16. The torch-native implementation implicitly upcasts
# FP16 x FP16 multiplications to FP32 before downcasting them, which leads
# to notable output divergences.
# Clone the tensors because the implementation modifies them in-place
ref_q_pe, ref_k_pe = rope.forward_hip(positions, query.clone(), k_pe.clone())
else:
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
assert ref_k_pe is not None
ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)