[Bugfix][ROCm] Fix lru_cache on paged_mqa_logits_module (#37547)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
This commit is contained in:
committed by
GitHub
parent
bc9c6fbbe6
commit
f26fcdfb9e
@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch(
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def paged_mqa_logits_module():
    """Locate and import the aiter paged-MQA-logits Triton module.

    Probes the two known module paths (the location moved inside aiter
    between releases) and imports the first one whose spec is resolvable.
    The result is cached for the life of the process via ``lru_cache``.

    Returns:
        The imported module, or ``None`` when aiter is not installed or
        neither candidate path can be imported.
    """
    candidates = (
        "aiter.ops.triton.pa_mqa_logits",
        # Path after aiter moved this kernel under ``attention``.
        "aiter.ops.triton.attention.pa_mqa_logits",
    )
    for path in candidates:
        try:
            # NOTE: find_spec on a dotted name imports the parent packages,
            # so it raises ModuleNotFoundError (an ImportError subclass)
            # when ``aiter`` itself is absent — treat that as "not found".
            if importlib.util.find_spec(path) is not None:
                return importlib.import_module(path)
        except ImportError:
            continue
    return None
|
||||||
|
|
||||||
|
|
||||||
def rocm_fp8_paged_mqa_logits(
|
def rocm_fp8_paged_mqa_logits(
|
||||||
q_fp8: torch.Tensor,
|
q_fp8: torch.Tensor,
|
||||||
kv_cache_fp8: torch.Tensor,
|
kv_cache_fp8: torch.Tensor,
|
||||||
@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits(
|
|||||||
"""
|
"""
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
|
||||||
@functools.lru_cache
|
|
||||||
def paged_mqa_logits_module():
|
|
||||||
paged_mqa_logits_module_path = None
|
|
||||||
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
|
|
||||||
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
|
|
||||||
elif (
|
|
||||||
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
|
|
||||||
is not None
|
|
||||||
):
|
|
||||||
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
|
|
||||||
|
|
||||||
if paged_mqa_logits_module_path is not None:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(paged_mqa_logits_module_path)
|
|
||||||
return module
|
|
||||||
except ImportError:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
aiter_paged_mqa_logits_module = None
|
aiter_paged_mqa_logits_module = None
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
|
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
|
||||||
@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch(
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def mqa_logits_module():
    """Locate and import the aiter fp8-MQA-logits Triton module.

    Probes the two known module paths (the location moved inside aiter
    between releases) and imports the first one whose spec is resolvable.
    The result is cached for the life of the process via ``lru_cache``.

    Returns:
        The imported module, or ``None`` when aiter is not installed or
        neither candidate path can be imported.
    """
    candidates = (
        "aiter.ops.triton.fp8_mqa_logits",
        # Path after aiter moved this kernel under ``attention``.
        "aiter.ops.triton.attention.fp8_mqa_logits",
    )
    for path in candidates:
        try:
            # NOTE: find_spec on a dotted name imports the parent packages,
            # so it raises ModuleNotFoundError (an ImportError subclass)
            # when ``aiter`` itself is absent — treat that as "not found".
            if importlib.util.find_spec(path) is not None:
                return importlib.import_module(path)
        except ImportError:
            continue
    return None
|
||||||
|
|
||||||
|
|
||||||
def rocm_fp8_mqa_logits(
|
def rocm_fp8_mqa_logits(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
kv: tuple[torch.Tensor, torch.Tensor],
|
kv: tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits(
|
|||||||
# path after aiter merge this kernel into main
|
# path after aiter merge this kernel into main
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
|
||||||
@functools.lru_cache
|
|
||||||
def mqa_logits_module():
|
|
||||||
mqa_logits_module_path = None
|
|
||||||
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
|
|
||||||
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
|
|
||||||
elif (
|
|
||||||
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
|
|
||||||
is not None
|
|
||||||
):
|
|
||||||
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
|
|
||||||
|
|
||||||
if mqa_logits_module_path is not None:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(mqa_logits_module_path)
|
|
||||||
return module
|
|
||||||
except ImportError:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
aiter_mqa_logits_module = None
|
aiter_mqa_logits_module = None
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
aiter_mqa_logits_module = mqa_logits_module()
|
aiter_mqa_logits_module = mqa_logits_module()
|
||||||
|
|||||||
Reference in New Issue
Block a user