[CI][Hardware][AMD] Fix test_rotary_embedding_mla_cache_fused (#32408)
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
+from vllm.utils.torch_utils import set_random_seed
@@ -68,9 +69,17 @@ def test_concat_and_cache_mla_rope_fused(
     k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
     kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)

-    # NOTE(woosuk): The reference implementation should be executed first
-    # because the custom kernel is in-place.
-    ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
+    if current_platform.is_rocm():
+        # We use forward_hip for the same numerics as the fused custom kernel on ROCm
+        # when dtype is FP16. The torch-native implementation implicitly upcasts
+        # FP16 x FP16 multiplications to FP32 before downcasting them, which leads
+        # to notable output divergences.
+        # Clone the tensors because the implementation modifies them in-place
+        ref_q_pe, ref_k_pe = rope.forward_hip(positions, query.clone(), k_pe.clone())
+    else:
+        # NOTE(woosuk): The reference implementation should be executed first
+        # because the custom kernel is in-place.
+        ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
     assert ref_k_pe is not None

     ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)
Reference in New Issue
Block a user