diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index a45ca9882..a0e5af1ab 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -138,16 +138,18 @@ def xpu_platform_plugin() -> str | None:
 
         if supports_xccl():
             dist_backend = "xccl"
-        else:
-            dist_backend = "ccl"
-            import oneccl_bindings_for_pytorch  # noqa: F401
-
-        if hasattr(torch, "xpu") and torch.xpu.is_available():
-            is_xpu = True
             from vllm.platforms.xpu import XPUPlatform
 
             XPUPlatform.dist_backend = dist_backend
             logger.debug("Confirmed %s backend is available.", XPUPlatform.dist_backend)
+        else:
+            logger.warning(
+                "xccl is not enabled in this torch build, "
+                "communication is not available."
+            )
+
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            is_xpu = True
         logger.debug("Confirmed XPU platform is available.")
     except Exception as e:
         logger.debug("XPU platform is not available because: %s", str(e))
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 611df090f..f1bdd5da3 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -7,14 +7,18 @@
 import torch
 import torch.distributed
 
 from vllm.config import VllmConfig
-from vllm.distributed import get_world_group
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.profiler.wrapper import TorchProfilerWrapper
+from vllm.utils.mem_utils import MemorySnapshot, format_gib
 from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
+from vllm.v1.worker.workspace import init_workspace_manager
 from vllm.v1.worker.xpu_model_runner import XPUModelRunner
 
+from .utils import request_memory
+
 logger = init_logger(__name__)
 
@@ -48,86 +52,6 @@ class XPUWorker(Worker):
                 activities=["CPU", "XPU"],
             )
 
-    # we provide this function due to `torch.xpu.mem_get_info()` doesn't
-    # return correct free_gpu_memory on intel client GPU. We need to
-    # calculate/estiamte it.
-    def xpu_get_mem_info(self):
-        if current_platform.is_data_center_gpu():
-            return torch.xpu.mem_get_info()
-        else:
-            _, total_gpu_memory = torch.xpu.mem_get_info()
-            # FIXME: memory_allocated() doesn't count non-torch allocations,
-            # and we don't have any API to get it. so we mark it as 128MB.
-            used_memory = torch.xpu.memory_allocated()
-            non_torch_allocations = 128 * 1024 * 1024
-            free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
-            return free_gpu_memory, total_gpu_memory
-
-    @torch.inference_mode()
-    def determine_available_memory(self) -> int:
-        """Profiles the peak memory usage of the model to determine how many
-        KV blocks may be allocated without OOMs.
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculates the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
-        """
-        # Profile the memory usage of the model and get the maximum number of
-        # cache blocks that can be allocated with the remaining free memory.
-        torch.xpu.empty_cache()
-        torch.xpu.reset_peak_memory_stats()
-
-        free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
-        current_allocated_bytes = torch.xpu.memory_allocated()
-        msg = (
-            "Before memory profiling run, "
-            f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
-            f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
-            f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
-        )
-        logger.info(msg)
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        self.model_runner.profile_run()
-
-        free_gpu_memory, _ = self.xpu_get_mem_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance."
-        )
-
-        # Get the peak memory allocation recorded by torch
-        peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
-
-        torch.xpu.empty_cache()
-        torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"]
-        free_mem, total_mem = self.xpu_get_mem_info()
-        total_allocated_bytes = total_mem - free_mem
-
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory
-        )
-
-        msg = (
-            "After memory profiling run, "
-            f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
-            f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
-            f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
-            f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
-        )
-        logger.info(msg)
-
-        return int(available_kv_cache_memory)
-
     def init_device(self):
         device = self.device_config.device
         if (
@@ -161,15 +85,26 @@ class XPUWorker(Worker):
             current_platform.dist_backend,
         )
 
-        # global all_reduce needed for overall oneccl warm up
-        torch.distributed.all_reduce(
-            torch.zeros(1).xpu(), group=get_world_group().device_group
+        torch.xpu.empty_cache()
+        self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
+        self.requested_memory = request_memory(init_snapshot, self.cache_config)
+        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
+        logger.debug(
+            "worker requested memory: %sGiB", format_gib(self.requested_memory)
         )
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
 
+        # Initialize workspace manager
+        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
+        init_workspace_manager(self.device, num_ubatches)
+
         # Construct the model runner
         self.model_runner = XPUModelRunner(  # type: ignore
            self.vllm_config, self.device
        )
+
+        if self.rank == 0:
+            # If usage stat is enabled, collect relevant info.
+            report_usage_stats(self.vllm_config)
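
Note: the removed `determine_available_memory()` is superseded by the `MemorySnapshot`/`request_memory` helpers imported above, whose internals are not part of this diff. For reference, the budgeting arithmetic the removed code performed reduces to the sketch below; this is an illustrative placeholder (the function name and parameters are not vLLM API), not the new helpers' implementation.

```python
# Illustrative sketch of the formula the removed determine_available_memory()
# used: the KV-cache budget is the utilization-scaled total device memory
# minus the peak memory observed during the profiling run, including an
# estimate for allocations torch cannot see. Names here are placeholders.
def kv_cache_budget(
    total_gpu_memory: int,          # bytes reported for the device
    gpu_memory_utilization: float,  # cache_config.gpu_memory_utilization
    peak_torch_bytes: int,          # "allocated_bytes.all.peak" after the profile run
    non_torch_bytes: int = 0,       # non-torch allocations (estimated)
) -> int:
    peak = peak_torch_bytes + max(non_torch_bytes, 0)
    return int(total_gpu_memory * gpu_memory_utilization - peak)


# Example: 16 GiB device, 0.9 utilization, 6 GiB torch peak, 128 MiB non-torch.
GiB = 1024**3
print(kv_cache_budget(16 * GiB, 0.9, 6 * GiB, 128 * 1024**2) / GiB)  # ~8.3 GiB
```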