From 6edd43de3ce2aa9ca93b8ece656af7547526afd3 Mon Sep 17 00:00:00 2001
From: JartX
Date: Tue, 17 Mar 2026 22:55:34 +0100
Subject: [PATCH] [Bugfix][ROCm] Fix worker startup OOM on ROCm by skipping
 unreliable cudagraph memory profiling (#36720)

Signed-off-by: JartX
(cherry picked from commit e8f9dbc369aa2086ec1e1fe3b104c582812cfc17)
---
 vllm/v1/worker/gpu_worker.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 58e2d658c..6d117175b 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -392,8 +392,10 @@ class Worker(WorkerBase):
         )

         # Profile CUDA graph memory if graphs will be captured.
+        # Skip on ROCm/HIP as graph pool handles and mem_get_info behave
+        # differently and can produce incorrect/negative estimates.
         cudagraph_memory_estimate = 0
-        if not self.model_config.enforce_eager:
+        if not self.model_config.enforce_eager and not current_platform.is_rocm():
             cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()

         # Use the pre-cudagraph torch peak to avoid double-counting.
@@ -406,6 +408,8 @@ class Worker(WorkerBase):
             + profile_result.weights_memory
         )

+        # On ROCm, cudagraph_memory_estimate is always 0 so this is a no-op.
+        # On CUDA, respect the opt-in flag as originally designed.
         cudagraph_memory_estimate_applied = (
             cudagraph_memory_estimate
             if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
@@ -517,7 +521,6 @@ class Worker(WorkerBase):

     def update_max_model_len(self, max_model_len: int) -> None:
         """Update max_model_len after auto-fit to GPU memory.
-
         This is called when max_model_len=-1 is used and the engine
         automatically determines the maximum context length that fits
         in GPU memory. Workers need to update their cached max_model_len
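
Note for reviewers: the following is a minimal, self-contained Python sketch
of the gating behavior this patch introduces, shown in isolation. The function
compute_cudagraph_estimate, its parameters, and the "else 0" fallback for when
the opt-in flag is unset are illustrative assumptions for this note, not vLLM
APIs; in the real code the inputs come from self.model_config.enforce_eager,
current_platform.is_rocm(), envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS, and
self.model_runner.profile_cudagraph_memory().

    from typing import Callable

    def compute_cudagraph_estimate(
        enforce_eager: bool,
        is_rocm: bool,
        estimate_flag: bool,
        profile_fn: Callable[[], int],
    ) -> int:
        # Mirrors the patched logic: profiling is skipped entirely on ROCm,
        # so the estimate stays 0 there and the opt-in flag becomes a no-op.
        estimate = 0
        if not enforce_eager and not is_rocm:
            estimate = profile_fn()
        # Assumed fallback: when the opt-in flag is off, no estimate applies.
        return estimate if estimate_flag else 0

    if __name__ == "__main__":
        # Pretend the profiler measured 512 MiB of graph memory.
        profile_fn = lambda: 512 * 1024 * 1024

        # CUDA, graphs enabled, flag on: profiled estimate is applied.
        assert compute_cudagraph_estimate(False, False, True, profile_fn) == 512 * 1024 * 1024
        # ROCm: profiling is skipped, so the estimate is always 0.
        assert compute_cudagraph_estimate(False, True, True, profile_fn) == 0
        # enforce_eager: no graphs will be captured, nothing to estimate.
        assert compute_cudagraph_estimate(True, False, True, profile_fn) == 0
        print("gating behaves as described in the commit message")

Run directly, the assertions confirm that the ROCm path always yields a zero
estimate, which is what keeps the unreliable/negative profiling results from
distorting the startup memory accounting and triggering the OOM.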