[Platforms] Refactor xpu code (#10468)
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -1,9 +1,16 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -34,3 +41,17 @@ class XPUPlatform(Platform):
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
    """Adjust ``vllm_config.model_config`` in place for the XPU platform.

    Two settings are rewritten, each with a logged warning:

    * a ``torch.bfloat16`` dtype is replaced by ``torch.float16``;
    * ``enforce_eager`` is switched on when it is not already set.

    Args:
        vllm_config: Configuration object whose ``model_config``
            attribute is inspected and mutated.
    """
    xpu_model_cfg = vllm_config.model_config

    # Downgrade bfloat16 to float16 and tell the user about the cast.
    if xpu_model_cfg.dtype == torch.bfloat16:
        logger.warning(
            "bfloat16 is not fully supported on XPU, casting to float16.")
        xpu_model_cfg.dtype = torch.float16

    # Nothing more to do when eager execution is already requested.
    if xpu_model_cfg.enforce_eager:
        return

    logger.warning(
        "CUDA graph is not supported on XPU, fallback to the eager "
        "mode.")
    xpu_model_cfg.enforce_eager = True
|
||||
|
||||
Reference in New Issue
Block a user