diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index caa4305a5..454d2301e 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -13,6 +13,7 @@ import vllm_xpu_kernels._moe_C # noqa import vllm_xpu_kernels._xpu_C # noqa from vllm.logger import init_logger +from vllm.utils.torch_utils import supports_xpu_graph from vllm.v1.attention.backends.registry import AttentionBackendEnum from .interface import DeviceCapability, Platform, PlatformEnum @@ -151,10 +152,15 @@ class XPUPlatform(Platform): def inference_mode(cls): return torch.no_grad() + @classmethod + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config # in V1(or with chunked prefill) block_size is 64 if cache_config and cache_config.block_size is None: cache_config.block_size = 64 @@ -166,9 +172,32 @@ class XPUPlatform(Platform): if compilation_config.compile_sizes is None: compilation_config.compile_sizes = [] - assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, ( - "CUDA graph mode should be NONE on XPU" - ) + attention_config = vllm_config.attention_config + if attention_config.backend is None: + attention_config.backend = AttentionBackendEnum.FLASH_ATTN + if not supports_xpu_graph(): + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + logger.warning( + "XPU Graph is not supported in the current PyTorch version, " + "disabling cudagraph_mode." + ) + elif parallel_config.world_size_across_dp > 1: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + logger.warning( + "XPU Graph doesn't support capture communication ops, " + "disabling cudagraph_mode." + ) + else: + if ( + attention_config.backend == AttentionBackendEnum.FLASH_ATTN + and compilation_config.cudagraph_mode + not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE} + ): + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + logger.warning( + "FMHA sycl-tla kernels cannot be captured with XPU graphs, " + "falling back to PIECEWISE graph mode on XPU platform." + ) if vllm_config.lora_config is not None: compilation_config.mode = CompilationMode.NONE @@ -201,7 +230,7 @@ class XPUPlatform(Platform): @classmethod def support_static_graph_mode(cls) -> bool: - return False + return True @classmethod def is_pin_memory_available(cls): diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 17a0ddd6d..e834108ca 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -745,6 +745,11 @@ def supports_xccl() -> bool: return torch.distributed.is_xccl_available() +# Supports XPU Graph with PyTorch versions >= 2.11.0.dev for XPU platform +def supports_xpu_graph() -> bool: + return is_torch_equal_or_newer("2.11.0.dev") + + # create a library to hold the custom op vllm_lib = Library("vllm", "FRAGMENT") # noqa diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index e2cd49990..8ca35b4c3 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -7,6 +7,7 @@ import torch from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.utils.torch_utils import supports_xpu_graph from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: @@ -40,6 +41,12 @@ def _torch_cuda_wrapper(): torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.stream = torch.xpu.stream + torch.cuda.mem_get_info = torch.xpu.mem_get_info + torch.cuda.synchronize = torch.xpu.synchronize + if supports_xpu_graph(): + torch.cuda.graph = torch.xpu.graph + torch.cuda.CUDAGraph = torch.xpu.XPUGraph + torch.cuda.empty_cache = torch.xpu.empty_cache yield finally: pass