[XPU] Support CUDAGraph on XPU Platform (#34482)
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: chzhang <chaojun.zhang@intel.com> Co-authored-by: zhenwei-intel <zhenwei.liu@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -13,6 +13,7 @@ import vllm_xpu_kernels._moe_C # noqa
|
||||
import vllm_xpu_kernels._xpu_C # noqa
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import supports_xpu_graph
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
@@ -151,10 +152,15 @@ class XPUPlatform(Platform):
|
||||
def inference_mode(cls):
    """Return the context manager used during inference on XPU.

    Uses plain ``torch.no_grad()`` to disable gradient tracking.
    NOTE(review): presumably chosen over ``torch.inference_mode()`` for
    XPU compatibility — confirm against the base ``Platform`` contract.
    """
    return torch.no_grad()
|
||||
|
||||
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
    """Return the fully-qualified class path of the static-graph wrapper.

    XPU reuses the CUDA graph wrapper implementation, so the CUDA
    wrapper's import path is returned for the XPU platform as well.
    """
    wrapper_path = "vllm.compilation.cuda_graph.CUDAGraphWrapper"
    return wrapper_path
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# in V1(or with chunked prefill) block_size is 64
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 64
|
||||
@@ -166,9 +172,32 @@ class XPUPlatform(Platform):
|
||||
if compilation_config.compile_sizes is None:
|
||||
compilation_config.compile_sizes = []
|
||||
|
||||
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, (
|
||||
"CUDA graph mode should be NONE on XPU"
|
||||
)
|
||||
attention_config = vllm_config.attention_config
|
||||
if attention_config.backend is None:
|
||||
attention_config.backend = AttentionBackendEnum.FLASH_ATTN
|
||||
if not supports_xpu_graph():
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
logger.warning(
|
||||
"XPU Graph is not supported in the current PyTorch version, "
|
||||
"disabling cudagraph_mode."
|
||||
)
|
||||
elif parallel_config.world_size_across_dp > 1:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
logger.warning(
|
||||
"XPU Graph doesn't support capture communication ops, "
|
||||
"disabling cudagraph_mode."
|
||||
)
|
||||
else:
|
||||
if (
|
||||
attention_config.backend == AttentionBackendEnum.FLASH_ATTN
|
||||
and compilation_config.cudagraph_mode
|
||||
not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE}
|
||||
):
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
logger.warning(
|
||||
"FMHA sycl-tla kernels cannot be captured with XPU graphs, "
|
||||
"falling back to PIECEWISE graph mode on XPU platform."
|
||||
)
|
||||
|
||||
if vllm_config.lora_config is not None:
|
||||
compilation_config.mode = CompilationMode.NONE
|
||||
@@ -201,7 +230,7 @@ class XPUPlatform(Platform):
|
||||
|
||||
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Report whether this platform supports static (captured) graph mode.

    Returns:
        True: XPU supports graph capture via the ``torch.xpu`` graph
        APIs. Runtime availability is still gated separately (e.g. by
        ``supports_xpu_graph()`` in ``check_and_update_config``), which
        falls back to ``CUDAGraphMode.NONE`` on older PyTorch builds.
    """
    # Fixed: the block previously contained an unreachable second
    # return (``return False`` followed by ``return True``); the
    # enabling value is True now that XPU graph capture is supported.
    return True
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
|
||||
@@ -745,6 +745,11 @@ def supports_xccl() -> bool:
|
||||
return torch.distributed.is_xccl_available()
|
||||
|
||||
|
||||
# XPU graph capture relies on torch.xpu graph APIs available from
# PyTorch 2.11 development builds onward.
def supports_xpu_graph() -> bool:
    """Return True when the installed PyTorch can capture XPU graphs."""
    minimum_required = "2.11.0.dev"
    return is_torch_equal_or_newer(minimum_required)
|
||||
|
||||
|
||||
# create a library to hold the custom op
|
||||
vllm_lib = Library("vllm", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import supports_xpu_graph
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -40,6 +41,12 @@ def _torch_cuda_wrapper():
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.mem_get_info = torch.xpu.mem_get_info
|
||||
torch.cuda.synchronize = torch.xpu.synchronize
|
||||
if supports_xpu_graph():
|
||||
torch.cuda.graph = torch.xpu.graph
|
||||
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user