[Bugfix][ROCm] Fix worker startup OOM on ROCm by skipping unreliable cudagraph memory profiling (#36720)

Signed-off-by: JartX <sagformas@epdcenter.es>
This commit is contained in:
JartX
2026-03-17 22:55:34 +01:00
committed by GitHub
parent de35c06c66
commit e8f9dbc369

View File

@@ -392,8 +392,10 @@ class Worker(WorkerBase):
)
# Profile CUDA graph memory if graphs will be captured.
# Skip on ROCm/HIP, where graph pool handles and mem_get_info behave
# differently from CUDA and can produce incorrect/negative estimates.
cudagraph_memory_estimate = 0
if not self.model_config.enforce_eager:
if not self.model_config.enforce_eager and not current_platform.is_rocm():
cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()
# Use the pre-cudagraph torch peak to avoid double-counting.
@@ -406,6 +408,8 @@ class Worker(WorkerBase):
+ profile_result.weights_memory
)
# On ROCm, cudagraph_memory_estimate is always 0, so this is a no-op.
# On CUDA, respect the opt-in flag as originally designed.
cudagraph_memory_estimate_applied = (
cudagraph_memory_estimate
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
@@ -517,7 +521,6 @@ class Worker(WorkerBase):
def update_max_model_len(self, max_model_len: int) -> None:
"""Update max_model_len after auto-fit to GPU memory.
This is called when max_model_len=-1 is used and the engine
automatically determines the maximum context length that fits
in GPU memory. Workers need to update their cached max_model_len