[Bugfix][ROCm] Fix lru_cache on paged_mqa_logits_module (#37547)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
This commit (f26fcdfb9e, committed via GitHub) has parent bc9c6fbbe6.
@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch(
|
||||
return logits
|
||||
|
||||
|
||||
@functools.lru_cache
def paged_mqa_logits_module():
    """Import the aiter Triton paged-MQA-logits kernel module, if available.

    Two candidate module paths are probed in order (the kernel is looked up
    under both ``aiter.ops.triton.pa_mqa_logits`` and
    ``aiter.ops.triton.attention.pa_mqa_logits``). The result is cached for
    the lifetime of the process via ``functools.lru_cache`` — safe here
    because this is a module-level function, so the cache object is created
    exactly once.

    Returns:
        The imported module, or ``None`` when aiter does not provide the
        kernel or the import fails.
    """
    candidates = (
        "aiter.ops.triton.pa_mqa_logits",
        "aiter.ops.triton.attention.pa_mqa_logits",
    )
    for path in candidates:
        try:
            # find_spec raises ModuleNotFoundError when a *parent* package
            # (e.g. "aiter" itself) is not installed; treat that as
            # "kernel unavailable" instead of letting it propagate.
            if importlib.util.find_spec(path) is None:
                continue
        except ModuleNotFoundError:
            return None
        try:
            return importlib.import_module(path)
        except ImportError:
            # Spec was found but the import still failed (broken install);
            # keep the original behavior and report the kernel as missing.
            return None
    return None
|
||||
|
||||
|
||||
def rocm_fp8_paged_mqa_logits(
|
||||
q_fp8: torch.Tensor,
|
||||
kv_cache_fp8: torch.Tensor,
|
||||
@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits(
|
||||
"""
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
def paged_mqa_logits_module():
    """Probe for and import the aiter Triton paged-MQA-logits kernel module.

    Bugfix: the original decorated this *nested* function with
    ``@functools.lru_cache``. Because the function object is recreated on
    every call of the enclosing function, that cache could never produce a
    hit and only added per-call allocation overhead, so the decorator is
    removed. (Hoisting the function to module level, where caching is
    effective, is the proper long-term fix.)

    Returns:
        The imported module, or ``None`` when it cannot be located or
        imported.
    """
    candidates = (
        "aiter.ops.triton.pa_mqa_logits",
        "aiter.ops.triton.attention.pa_mqa_logits",
    )
    for path in candidates:
        try:
            # find_spec raises ModuleNotFoundError when a parent package
            # (e.g. "aiter" itself) is absent; report "unavailable"
            # instead of crashing.
            if importlib.util.find_spec(path) is None:
                continue
        except ModuleNotFoundError:
            return None
        try:
            return importlib.import_module(path)
        except ImportError:
            # Spec exists but the import failed (broken install).
            return None
    return None
|
||||
|
||||
aiter_paged_mqa_logits_module = None
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
|
||||
@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch(
|
||||
return logits
|
||||
|
||||
|
||||
@functools.lru_cache
def mqa_logits_module():
    """Import the aiter Triton fp8-MQA-logits kernel module, if available.

    Two candidate module paths are probed in order
    (``aiter.ops.triton.fp8_mqa_logits`` and
    ``aiter.ops.triton.attention.fp8_mqa_logits``). The result is cached
    process-wide via ``functools.lru_cache`` — safe here because this is a
    module-level function, so the cache object is created exactly once.

    Returns:
        The imported module, or ``None`` when aiter does not provide the
        kernel or the import fails.
    """
    candidates = (
        "aiter.ops.triton.fp8_mqa_logits",
        "aiter.ops.triton.attention.fp8_mqa_logits",
    )
    for path in candidates:
        try:
            # find_spec raises ModuleNotFoundError when a *parent* package
            # (e.g. "aiter" itself) is not installed; treat that as
            # "kernel unavailable" instead of letting it propagate.
            if importlib.util.find_spec(path) is None:
                continue
        except ModuleNotFoundError:
            return None
        try:
            return importlib.import_module(path)
        except ImportError:
            # Spec was found but the import still failed (broken install);
            # keep the original behavior and report the kernel as missing.
            return None
    return None
|
||||
|
||||
|
||||
def rocm_fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits(
|
||||
# path after aiter merge this kernel into main
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
def mqa_logits_module():
    """Probe for and import the aiter Triton fp8-MQA-logits kernel module.

    Bugfix: the original decorated this *nested* function with
    ``@functools.lru_cache``. Because the function object is recreated on
    every call of the enclosing function, that cache could never produce a
    hit and only added per-call allocation overhead, so the decorator is
    removed. (Hoisting the function to module level, where caching is
    effective, is the proper long-term fix.)

    Returns:
        The imported module, or ``None`` when it cannot be located or
        imported.
    """
    candidates = (
        "aiter.ops.triton.fp8_mqa_logits",
        "aiter.ops.triton.attention.fp8_mqa_logits",
    )
    for path in candidates:
        try:
            # find_spec raises ModuleNotFoundError when a parent package
            # (e.g. "aiter" itself) is absent; report "unavailable"
            # instead of crashing.
            if importlib.util.find_spec(path) is None:
                continue
        except ModuleNotFoundError:
            return None
        try:
            return importlib.import_module(path)
        except ImportError:
            # Spec exists but the import failed (broken install).
            return None
    return None
|
||||
|
||||
aiter_mqa_logits_module = None
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
aiter_mqa_logits_module = mqa_logits_module()
|
||||
|
||||
Reference in New Issue
Block a user