[Hardware][Intel GPU] Add Intel GPU(XPU) inference backend (#3814)
Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Abhilash Majumder <abhilash.majumder@intel.com>
Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
@@ -160,6 +160,26 @@ def is_tpu() -> bool:
     return libtpu is not None
 
 
+@lru_cache(maxsize=None)
+def is_xpu() -> bool:
+    from importlib.metadata import version
+    is_xpu_flag = "xpu" in version("vllm")
+    # vllm is not built with xpu
+    if not is_xpu_flag:
+        return False
+    try:
+        import intel_extension_for_pytorch as ipex  # noqa: F401
+        _import_ipex = True
+    except ImportError as e:
+        logger.warning("Import Error for IPEX: %s", e.msg)
+        _import_ipex = False
+    # ipex dependency is not ready
+    if not _import_ipex:
+        logger.warning("IPEX library not found.")
+        return False
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
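For context, a minimal sketch of how callers can dispatch on this check. The `select_device` helper below is hypothetical and not part of this diff; it only assumes the `is_xpu()` function added above and standard `torch` APIs.

import torch

def select_device() -> torch.device:
    # Hypothetical helper: prefer CUDA, then XPU, else fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if is_xpu():  # defined above; True only when vLLM is built with XPU + IPEX
        return torch.device("xpu")
    return torch.device("cpu")

# Usage: allocate a tensor on whichever backend is present.
x = torch.empty(1024, device=select_device())

Note the result of `is_xpu()` is cached via `@lru_cache`, so the version string and IPEX import are probed only once per process.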
@@ -482,6 +502,9 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                            "This may slow down the performance.")
         return False
+    elif is_xpu():
+        print_warning_once("Pin memory is not supported on XPU.")
+        return False
     elif is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
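The returned flag feeds host-side tensor allocation. A hedged sketch of the pattern (the buffer shape and dtype are illustrative only):

import torch

# Pinned memory speeds up host-to-device copies on CUDA; on XPU, WSL,
# and Neuron the check above returns False and pageable memory is used.
pin = is_pin_memory_available()
cpu_buffer = torch.empty(4096, dtype=torch.float16, pin_memory=pin)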
@@ -497,8 +520,12 @@ class CudaMemoryProfiler:
 
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
-        torch.cuda.reset_peak_memory_stats(self.device)
-        mem = torch.cuda.max_memory_allocated(self.device)
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(self.device)
+            mem = torch.cuda.max_memory_allocated(self.device)
+        elif is_xpu():
+            torch.xpu.reset_peak_memory_stats(self.device)
+            mem = torch.xpu.max_memory_allocated(self.device)
         return mem
 
     def __enter__(self):
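A short usage sketch of the profiler as a context manager. This assumes `__enter__`/`__exit__` snapshot `current_memory_usage()` and store the delta in `consumed_memory` (consistent with the class's design, though those methods are outside this hunk); the workload is illustrative.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "xpu")
with CudaMemoryProfiler(device) as m:  # assumes the ctor accepts a device
    model = torch.nn.Linear(4096, 4096).to(device)  # illustrative workload
print(f"Model weights consumed {m.consumed_memory / 1024**2:.1f} MiB")

With this change the same profiling path works unmodified on both CUDA and XPU, since only the backend-specific peak-memory calls differ.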