diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 33460222e..85d0744db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -127,6 +127,13 @@ repos: language: python types: [python] additional_dependencies: [regex] + # prevent use torch.cuda APIs + - id: check-torch-cuda-call + name: "Prevent new 'torch.cuda' APIs call" + entry: python tools/pre_commit/check_torch_cuda.py + language: python + types: [python] + additional_dependencies: [regex] - id: validate-config name: Validate configuration has default values and that each field has a docstring entry: python tools/pre_commit/validate_config.py diff --git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py index cac332a09..aa020e012 100644 --- a/benchmarks/benchmark_topk_topp.py +++ b/benchmarks/benchmark_topk_topp.py @@ -102,7 +102,7 @@ def reset_memory_stats(): """Reset peak memory statistics.""" reset_buffer_cache() torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() gc.collect() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 4abeaefd7..3bd3e3f67 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -54,7 +54,7 @@ def clear_triton_cache(): # Clear CUDA memory cache if torch.cuda.is_available(): - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # Try to clear Triton's runtime cache try: diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py index 99067d8ac..b4c949e4f 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -104,7 +104,7 @@ def run_benchmark( # free tensors to mitigate OOM when sweeping del key, value, key_cache, value_cache, slot_mapping - torch.cuda.empty_cache() + torch.accelerator.empty_cache() return lat diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py 
b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index ef6be1f3c..2a250620b 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -129,7 +129,7 @@ def run_benchmark( # free tensors to mitigate OOM when sweeping del key, value, key_cache, value_cache, slot_mapping - torch.cuda.empty_cache() + torch.accelerator.empty_cache() return lat diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index 2f3564b59..ee5bbd82c 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -120,7 +120,7 @@ def main(): # Clean up the GPU memory for the next test del engine gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() if __name__ == "__main__": diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 241aa0ad8..47dc86fa2 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -159,7 +159,7 @@ class RayTrainingActor: s.close() del buffer gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # Ray manages four GPUs. 
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index 5c0787b87..a515917f0 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -150,7 +150,7 @@ class ColocateWorkerExtension: socket.close() del buffer gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() def report_device_id(self) -> str: from vllm.platforms import current_platform diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py index 6dec603a5..3dcc3c3df 100644 --- a/tests/compile/test_dynamic_shapes_compilation.py +++ b/tests/compile/test_dynamic_shapes_compilation.py @@ -99,7 +99,7 @@ def test_dynamic_shapes_compilation( # Clean up GPU memory del model gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() torch.cuda.synchronize() print("GPU memory cleared") diff --git a/tests/conftest.py b/tests/conftest.py index 164cbeee2..413e21067 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1533,7 +1533,7 @@ def clean_gpu_memory_between_tests(): # Clean up GPU memory after the test if torch.cuda.is_available(): - torch.cuda.empty_cache() + torch.accelerator.empty_cache() gc.collect() diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index 80b7cd9f4..3cb64d50a 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -24,7 +24,7 @@ LORA_PATH = "davzoku/finqa_adapter_1b" def _cleanup(): gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() @pytest.fixture(autouse=True) diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index 039f2fc06..1d10bd297 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -273,7 +273,7 @@ def test_causal_conv1d_varlen( 
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype ): device = "cuda" - torch.cuda.empty_cache() + torch.accelerator.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index cda0b5c11..f8e2a8b52 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -769,7 +769,7 @@ def test_mixtral_moe( requires_grad=False, ) torch.cuda.synchronize() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # FIXME (zyongye) fix this after we move self.kernel # assignment in FusedMoE.__init__ diff --git a/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py index ed5129e1c..610f69c8d 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py +++ b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py @@ -178,7 +178,7 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): finally: del model gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref): @@ -200,7 +200,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref) finally: del model gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @@ -283,7 +283,7 @@ def test_vllm_tensorized_model_has_same_outputs( model_ref, vllm_runner, tmp_path, model_path ): gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() config = TensorizerConfig(tensorizer_uri=str(model_path)) args = EngineArgs(model=model_ref) diff --git a/tests/test_regression.py b/tests/test_regression.py index 8a9829e4d..2fc0308ff 100644 --- a/tests/test_regression.py 
+++ b/tests/test_regression.py @@ -49,7 +49,7 @@ def test_gc(): del llm gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # The memory allocated for model and KV cache should be released. # The memory allocated for PyTorch and others should be less than 50MB. diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py index 4bf76da45..726e9d89d 100644 --- a/tests/v1/e2e/test_async_spec_decode.py +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -125,7 +125,7 @@ def test_no_sync_with_spec_decode( assert len(outputs[0].outputs[0].text) > 0 del llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() sync_tracker.assert_no_sync() diff --git a/tests/v1/e2e/test_lora_with_spec_decode.py b/tests/v1/e2e/test_lora_with_spec_decode.py index 8c9ab58c3..5cbdc4123 100644 --- a/tests/v1/e2e/test_lora_with_spec_decode.py +++ b/tests/v1/e2e/test_lora_with_spec_decode.py @@ -95,7 +95,7 @@ def test_batch_inference_correctness( prompts, sampling_params, lora_request=lora_request ) del ref_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() lora_spec_llm = LLM( @@ -135,5 +135,5 @@ def test_batch_inference_correctness( print(f"match ratio: {matches}/{len(ref_outputs)}") assert matches > int(0.90 * len(ref_outputs)) del lora_spec_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 5aa72ccb3..3ba7651c3 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -440,7 +440,7 @@ def _run_ref_mamba_state_worker(): torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth") mamba_kv_cache_dict.clear() del engine - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() except Exception: traceback.print_exc() @@ -805,5 +805,5 @@ def 
test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check) mamba_kv_cache_dict.clear() del engine - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4c90df5f4..4066dfe9e 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -179,7 +179,7 @@ def test_ngram_and_suffix_correctness( ) evaluate_llm_for_gsm8k(spec_llm) del spec_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() @@ -240,7 +240,7 @@ def test_suffix_decoding_acceptance( assert last_accept_rate > 0.80 del spec_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() @@ -307,14 +307,14 @@ def test_speculators_model_integration( verifier_model = spec_llm.llm_engine.vllm_config.model_config.model del spec_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() # Second run: Reference without speculative decoding ref_llm = LLM(model=verifier_model, max_model_len=4096) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() # Compare outputs @@ -410,7 +410,7 @@ def _run_eagle_correctness( ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() spec_llm = LLM( @@ -445,7 +445,7 @@ def _run_eagle_correctness( assert matches > int(0.6 * len(ref_outputs)) del spec_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() @@ -715,7 +715,7 @@ def test_mtp_correctness( ref_llm, expected_accuracy_threshold=expected_accuracy_threshold ) del ref_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() 
cleanup_dist_env_and_memory() spec_llm = LLM( @@ -747,7 +747,7 @@ def test_mtp_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)) del spec_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() @@ -952,7 +952,7 @@ def assert_draft_model_correctness(args: ArgsTest): ) del spec_llm # CLEANUP - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() print( diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index c6c9c0ce4..aa084eee8 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -857,7 +857,7 @@ def test_structured_output_batched_with_non_structured_outputs_requests( # Free memory as soon as possible as failed assertions # will short circuit and not free up memory del llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() for index, output in enumerate(outputs): diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 8a384dd84..3a83f835c 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -530,7 +530,7 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): assert positive_values > 0 finally: del llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() @@ -1065,7 +1065,7 @@ def test_spec_decode_logprobs( for logprobs in output.logprobs: ref_logprobs.extend(logprobs.values()) del ref_llm - torch.cuda.empty_cache() + torch.accelerator.empty_cache() cleanup_dist_env_and_memory() # Run spec decode LLM. 
# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.xxx` calls that have a device-agnostic
# `torch.accelerator.xxx` replacement (the latter is always allowed).
# --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [
    r"\btorch\.cuda\.empty_cache\b",
]

# Compiled once at import time; `scan_file` runs once per file handed over by
# pre-commit, so recompiling per call would be wasted work.
_COMPILED_TORCH_CUDA_PATTERNS = tuple(
    re.compile(pattern) for pattern in _TORCH_CUDA_PATTERNS
)

# Path prefixes that are exempt from the check.
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}


def scan_file(path: str) -> int:
    """Report every forbidden ``torch.cuda`` reference found in *path*.

    One diagnostic line of the form ``<path>:<line>: error: ...`` is printed
    per occurrence, in ascending line order, so that all violations in a file
    can be fixed in a single pass instead of one hook run per violation.

    Args:
        path: Path of the file to scan; it is read as UTF-8 text.

    Returns:
        1 if at least one forbidden reference was found, otherwise 0.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()

    # Collect the start offset of every match across all patterns and sort
    # them so the report follows the order of the file.
    offsets = sorted(
        match.start()
        for pattern in _COMPILED_TORCH_CUDA_PATTERNS
        for match in pattern.finditer(content)
    )

    for offset in offsets:
        # 1-based line number: newlines strictly before the match, plus one.
        line_num = content.count("\n", 0, offset) + 1
        print(
            f"{path}:{line_num}: "
            "\033[91merror:\033[0m "  # red color
            "Found torch.cuda API call"
        )

    return 1 if offsets else 0


def main() -> int:
    """Scan every file named on the command line.

    Files whose path starts with one of ``ALLOWED_FILES`` are skipped without
    being opened.

    Returns:
        0 when no scanned file contains a forbidden reference, otherwise 1
        (used as the process exit status).
    """
    allowed_prefixes = tuple(ALLOWED_FILES)
    returncode = 0
    for filename in sys.argv[1:]:
        if filename.startswith(allowed_prefixes):
            continue
        returncode |= scan_file(filename)
    return returncode


if __name__ == "__main__":
    sys.exit(main())
graphs. stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context(patch("torch.cuda.empty_cache", lambda: None)) + stack.enter_context( + patch("torch.accelerator.empty_cache", lambda: None) + ) if self.graph_pool is not None: set_graph_pool_id(self.graph_pool) diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py index 22d570660..f32ea39fb 100644 --- a/vllm/distributed/elastic_ep/elastic_execute.py +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -408,7 +408,7 @@ class ElasticEPScalingExecutor: gc.collect() torch.cuda.synchronize() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() unlock_workspace() self.worker.compile_or_warm_up_model() lock_workspace() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index fc554bd75..d0a67cf84 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1916,14 +1916,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): gc.collect() from vllm.platforms import current_platform - empty_cache = current_platform.empty_cache - if empty_cache is not None: - empty_cache() - try: - if not current_platform.is_cpu(): + if not current_platform.is_cpu(): + torch.accelerator.empty_cache() + try: torch._C._host_emptyCache() - except AttributeError: - logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5") + except AttributeError: + logger.warning( + "torch._C._host_emptyCache() only available in Pytorch >=2.5" + ) def in_the_same_node_as( diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 95b6f7b77..a29d8a7d8 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -200,7 +200,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): 
): num_pad = 256 // weight.element_size() weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() + torch.accelerator.empty_cache() return weight diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index b7cb84e8f..0a5db4e71 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -961,7 +961,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): # secondly, process mxfp weights if self.emulate: - torch.cuda.empty_cache() + torch.accelerator.empty_cache() return from aiter.utility.fp4_utils import e8m0_shuffle @@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w13_weight.is_shuffled = True layer.w2_weight.is_shuffled = True - torch.cuda.empty_cache() + torch.accelerator.empty_cache() def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -1116,7 +1116,7 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): del layer.w2_weight layer.w13_weight = None layer.w2_weight = None - torch.cuda.empty_cache() + torch.accelerator.empty_cache() if self.static_input_scales: if layer.w13_input_scale is None or layer.w2_input_scale is None: diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ee3f2ce96..41d44e0c4 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1407,7 +1407,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: import torch.nn.functional as F weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() + torch.accelerator.empty_cache() return weight diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py 
b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 40b33cdc5..81526415f 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -811,7 +811,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): **stacked_quant_state_dict, } self._bind_quant_states_to_params(model, stacked_quant_state_dict) - torch.cuda.empty_cache() + torch.accelerator.empty_cache() def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) diff --git a/vllm/utils/mem_utils.py b/vllm/utils/mem_utils.py index 0b3971126..30e38b0bf 100644 --- a/vllm/utils/mem_utils.py +++ b/vllm/utils/mem_utils.py @@ -96,7 +96,7 @@ class MemorySnapshot: # rather than `torch.cuda.memory_reserved()` . # After `torch.cuda.reset_peak_memory_stats()`, # `torch.cuda.memory_reserved()` will keep growing, and only shrink - # when we call `torch.cuda.empty_cache()` or OOM happens. + # when we call `torch.accelerator.empty_cache()` or OOM happens. self.torch_peak = current_platform.memory_stats(device).get( "allocated_bytes.all.peak", 0 ) @@ -250,7 +250,7 @@ def memory_profiling( until after profiling to get (c.). 
""" gc.collect() - current_platform.empty_cache() + torch.accelerator.empty_cache() current_platform.reset_peak_memory_stats(baseline_snapshot.device_) result = MemoryProfilingResult( @@ -264,7 +264,7 @@ def memory_profiling( yield result gc.collect() - current_platform.empty_cache() + torch.accelerator.empty_cache() result.after_profile.measure() diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 114936129..050165ea5 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -1036,4 +1036,4 @@ def apply_top_k_top_p_triton( def reset_buffer_cache(): _TRITON_BUFFER_CACHE.clear() _TRITON_TABLE_CACHE.clear() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 9267e1874..203d31195 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -496,7 +496,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): start_time = time.perf_counter() gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() start_free_gpu_memory = torch.cuda.mem_get_info()[0] with self.maybe_setup_dummy_loras(self.lora_config): diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c0654abd5..4c11aede5 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -278,7 +278,7 @@ class Worker(WorkerBase): # Now take memory snapshot after NCCL is initialized gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # take current memory snapshot self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) @@ -585,7 +585,7 @@ class Worker(WorkerBase): # sampling related tensors of max possible shape to avoid memory # fragmentation issue. # NOTE: This is called after `capture_model` on purpose to prevent - # memory buffers from being cleared by `torch.cuda.empty_cache`. 
+ # memory buffers from being cleared by `torch.accelerator.empty_cache`. max_num_reqs = min( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 8ca35b4c3..540c9cb20 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -46,7 +46,6 @@ def _torch_cuda_wrapper(): if supports_xpu_graph(): torch.cuda.graph = torch.xpu.graph torch.cuda.CUDAGraph = torch.xpu.XPUGraph - torch.cuda.empty_cache = torch.xpu.empty_cache yield finally: pass diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 6e45a107c..24fc65066 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -62,7 +62,7 @@ class XPUWorker(Worker): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype) - torch.xpu.empty_cache() + torch.accelerator.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( self.local_rank ).total_memory @@ -90,7 +90,7 @@ class XPUWorker(Worker): # Now take memory snapshot after NCCL is initialized gc.collect() - torch.xpu.empty_cache() + torch.accelerator.empty_cache() # take current memory snapshot self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)