[Platforms] Refactor xpu code (#10468)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2024-11-20 14:52:13 +08:00
committed by GitHub
parent 09dbf9ff16
commit d5b28447e0
2 changed files with 21 additions and 27 deletions

View File

@@ -1,9 +1,16 @@
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
@@ -34,3 +41,17 @@ class XPUPlatform(Platform):
@staticmethod
def inference_mode():
return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# check and update model config
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
logger.warning(
"bfloat16 is not fully supported on XPU, casting to float16.")
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True