From f26fcdfb9e50fef30381ed27fa956f7a43b0b1aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Thu, 26 Mar 2026 21:01:05 +0200
Subject: [PATCH] [Bugfix][ROCm] Fix lru_cache on paged_mqa_logits_module (#37547)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Stig-Arne Grönroos
---
 .../v1/attention/ops/rocm_aiter_mla_sparse.py | 77 ++++++++++---------
 1 file changed, 39 insertions(+), 38 deletions(-)

diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 878ae3aac..9d1da5b53 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch(
     return logits
 
 
+@functools.lru_cache
+def paged_mqa_logits_module():
+    paged_mqa_logits_module_path = None
+    if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
+        paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
+    elif (
+        importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None
+    ):
+        paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
+
+    if paged_mqa_logits_module_path is not None:
+        try:
+            module = importlib.import_module(paged_mqa_logits_module_path)
+            return module
+        except ImportError:
+            return None
+    return None
+
+
 def rocm_fp8_paged_mqa_logits(
     q_fp8: torch.Tensor,
     kv_cache_fp8: torch.Tensor,
@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits(
     """
     from vllm._aiter_ops import rocm_aiter_ops
 
-    @functools.lru_cache
-    def paged_mqa_logits_module():
-        paged_mqa_logits_module_path = None
-        if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
-            paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
-        elif (
-            importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
-            is not None
-        ):
-            paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
-
-        if paged_mqa_logits_module_path is not None:
-            try:
-                module = importlib.import_module(paged_mqa_logits_module_path)
-                return module
-            except ImportError:
-                return None
-        return None
-
     aiter_paged_mqa_logits_module = None
     if rocm_aiter_ops.is_enabled():
         aiter_paged_mqa_logits_module = paged_mqa_logits_module()
@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch(
     return logits
 
 
+@functools.lru_cache
+def mqa_logits_module():
+    mqa_logits_module_path = None
+    if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
+        mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
+    elif (
+        importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
+        is not None
+    ):
+        mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
+
+    if mqa_logits_module_path is not None:
+        try:
+            module = importlib.import_module(mqa_logits_module_path)
+            return module
+        except ImportError:
+            return None
+    return None
+
+
 def rocm_fp8_mqa_logits(
     q: torch.Tensor,
     kv: tuple[torch.Tensor, torch.Tensor],
@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits(
     # path after aiter merge this kernel into main
     from vllm._aiter_ops import rocm_aiter_ops
 
-    @functools.lru_cache
-    def mqa_logits_module():
-        mqa_logits_module_path = None
-        if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
-            mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
-        elif (
-            importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
-            is not None
-        ):
-            mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
-
-        if mqa_logits_module_path is not None:
-            try:
-                module = importlib.import_module(mqa_logits_module_path)
-                return module
-            except ImportError:
-                return None
-        return None
-
     aiter_mqa_logits_module = None
     if rocm_aiter_ops.is_enabled():
         aiter_mqa_logits_module = mqa_logits_module()