[XPU] Support CUDAGraph on XPU Platform (#34482)

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: chzhang <chaojun.zhang@intel.com>
Co-authored-by: zhenwei-intel <zhenwei.liu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Xinyu Chen
2026-02-25 14:22:52 +08:00
committed by GitHub
parent 8ad54a991b
commit 35d44b4557
3 changed files with 45 additions and 4 deletions

View File

@@ -13,6 +13,7 @@ import vllm_xpu_kernels._moe_C # noqa
import vllm_xpu_kernels._xpu_C # noqa
from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum
@@ -151,10 +152,15 @@ class XPUPlatform(Platform):
def inference_mode(cls):
    """Return the context manager used to run inference on XPU.

    Gradient tracking is simply disabled via ``torch.no_grad``; no
    device-specific inference mode is used here.
    """
    grad_disabled_ctx = torch.no_grad()
    return grad_disabled_ctx
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
    """Return the fully-qualified path of the static-graph wrapper class.

    XPU reuses the CUDA-graph wrapper implementation
    (NOTE(review): presumably the wrapper is device-agnostic — confirm).
    """
    wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
    return wrapper_path
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None:
cache_config.block_size = 64
@@ -166,9 +172,32 @@ class XPUPlatform(Platform):
if compilation_config.compile_sizes is None:
compilation_config.compile_sizes = []
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, (
"CUDA graph mode should be NONE on XPU"
)
attention_config = vllm_config.attention_config
if attention_config.backend is None:
attention_config.backend = AttentionBackendEnum.FLASH_ATTN
if not supports_xpu_graph():
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
logger.warning(
"XPU Graph is not supported in the current PyTorch version, "
"disabling cudagraph_mode."
)
elif parallel_config.world_size_across_dp > 1:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
logger.warning(
"XPU Graph doesn't support capture communication ops, "
"disabling cudagraph_mode."
)
else:
if (
attention_config.backend == AttentionBackendEnum.FLASH_ATTN
and compilation_config.cudagraph_mode
not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE}
):
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
logger.warning(
"FMHA sycl-tla kernels cannot be captured with XPU graphs, "
"falling back to PIECEWISE graph mode on XPU platform."
)
if vllm_config.lora_config is not None:
compilation_config.mode = CompilationMode.NONE
@@ -201,7 +230,7 @@ class XPUPlatform(Platform):
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Whether this platform supports static-graph (CUDA-graph-style) capture.

    Returns:
        True — XPU supports graph capture; finer-grained gating (e.g. the
        required PyTorch version) is handled by the config-check logic.
    """
    # The original block contained an unreachable `return True` after
    # `return False` (stripped diff residue); the commit's intent —
    # "Support CUDAGraph on XPU Platform" — is to return True.
    return True
@classmethod
def is_pin_memory_available(cls):

View File

@@ -745,6 +745,11 @@ def supports_xccl() -> bool:
return torch.distributed.is_xccl_available()
# XPU Graph requires PyTorch >= 2.11.0.dev on the XPU platform.
def supports_xpu_graph() -> bool:
    """Return True when the installed PyTorch version supports XPU graphs."""
    minimum_torch_version = "2.11.0.dev"
    return is_torch_equal_or_newer(minimum_torch_version)
# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa

View File

@@ -7,6 +7,7 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
@@ -40,6 +41,12 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.synchronize = torch.xpu.synchronize
if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.empty_cache = torch.xpu.empty_cache
yield
finally:
pass