[Intel GPU] refine xpu worker (#32894)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-01-29 20:26:52 +08:00
committed by GitHub
parent 8b3f0a99dd
commit 8bb6271c77
2 changed files with 27 additions and 90 deletions

View File

@@ -138,16 +138,18 @@ def xpu_platform_plugin() -> str | None:
if supports_xccl():
dist_backend = "xccl"
else:
dist_backend = "ccl"
import oneccl_bindings_for_pytorch # noqa: F401
if hasattr(torch, "xpu") and torch.xpu.is_available():
is_xpu = True
from vllm.platforms.xpu import XPUPlatform
XPUPlatform.dist_backend = dist_backend
logger.debug("Confirmed %s backend is available.", XPUPlatform.dist_backend)
else:
logger.warning(
"xccl is not enabled in this torch build, "
"communication is not available."
)
if hasattr(torch, "xpu") and torch.xpu.is_available():
is_xpu = True
logger.debug("Confirmed XPU platform is available.")
except Exception as e:
logger.debug("XPU platform is not available because: %s", str(e))

View File

@@ -7,14 +7,18 @@ import torch
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import get_world_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.workspace import init_workspace_manager
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
from .utils import request_memory
logger = init_logger(__name__)
@@ -48,86 +52,6 @@ class XPUWorker(Worker):
activities=["CPU", "XPU"],
)
# we provide this function due to `torch.xpu.mem_get_info()` doesn't
# return correct free_gpu_memory on intel client GPU. We need to
# calculate/estiamte it.
def xpu_get_mem_info(self):
if current_platform.is_data_center_gpu():
return torch.xpu.mem_get_info()
else:
_, total_gpu_memory = torch.xpu.mem_get_info()
# FIXME: memory_allocated() doesn't count non-torch allocations,
# and we don't have any API to get it. so we mark it as 128MB.
used_memory = torch.xpu.memory_allocated()
non_torch_allocations = 128 * 1024 * 1024
free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
return free_gpu_memory, total_gpu_memory
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
torch.xpu.reset_peak_memory_stats()
free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
current_allocated_bytes = torch.xpu.memory_allocated()
msg = (
"Before memory profiling run, "
f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
)
logger.info(msg)
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
free_gpu_memory, _ = self.xpu_get_mem_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
assert self.init_gpu_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
# Get the peak memory allocation recorded by torch
peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
torch.xpu.empty_cache()
torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"]
free_mem, total_mem = self.xpu_get_mem_info()
total_allocated_bytes = total_mem - free_mem
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory
)
msg = (
"After memory profiling run, "
f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
)
logger.info(msg)
return int(available_kv_cache_memory)
def init_device(self):
device = self.device_config.device
if (
@@ -161,15 +85,26 @@ class XPUWorker(Worker):
current_platform.dist_backend,
)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(
torch.zeros(1).xpu(), group=get_world_group().device_group
torch.xpu.empty_cache()
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = request_memory(init_snapshot, self.cache_config)
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
logger.debug(
"worker requested memory: %sGiB", format_gib(self.requested_memory)
)
# Set random seed.
set_random_seed(self.model_config.seed)
# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
# Construct the model runner
self.model_runner = XPUModelRunner( # type: ignore
self.vllm_config, self.device
)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)