[Bugfix][plugin] fla crash on plugin (#27322)
This commit is contained in:
@@ -17,6 +17,7 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
|
|||||||
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
||||||
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
||||||
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
||||||
device = get_available_device() if get_available_device() != "hip" else "cuda"
|
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
|
||||||
device_torch_lib = getattr(torch, device)
|
device_torch_lib = getattr(torch, device, None)
|
||||||
device_platform = _check_platform()
|
device_platform = _check_platform()
|
||||||
|
|
||||||
is_amd = device_platform == "amd"
|
is_amd = device_platform == "amd"
|
||||||
|
|||||||
Reference in New Issue
Block a user