diff --git a/vllm/v1/worker/gpu/warmup.py b/vllm/v1/worker/gpu/warmup.py index 082b4e642..28e480134 100644 --- a/vllm/v1/worker/gpu/warmup.py +++ b/vllm/v1/worker/gpu/warmup.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from typing import Any + import numpy as np import torch @@ -17,9 +20,14 @@ from vllm.v1.worker.gpu.model_runner import GPUModelRunner @torch.inference_mode() -def warmup_kernels(model_runner: GPUModelRunner) -> None: +def warmup_kernels( + model_runner: GPUModelRunner, + worker_execute_model: Callable[[SchedulerOutput], Any], + worker_sample_tokens: Callable[[GrammarOutput | None], Any], +) -> None: """Run two execute_model + sample_tokens iterations to JIT compile - triton kernels. + triton kernels. We must call the provided worker's execute_model for + pipeline parallel coordination. The first iteration simulates a prefill with requests of 2 prompt tokens each. The second iteration simulates a decode step with all @@ -83,7 +91,7 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None: # Disable KV connector for warmup run. model_runner.kv_connector.set_disabled(True) - model_runner.execute_model(prefill_output) + worker_execute_model(prefill_output) if not model_runner.is_pooling_model: # Warm up sampler and perform a decode step for non-pooling models. @@ -101,7 +109,7 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None: structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask ) - model_runner.sample_tokens(grammar_output) + worker_sample_tokens(grammar_output) # Step 2: Decode all requests with 1 token each. cached_req_data = CachedRequestData.make_empty() @@ -120,12 +128,12 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None: decode_output.total_num_scheduled_tokens = num_reqs decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups - model_runner.execute_model(decode_output) - model_runner.sample_tokens(None) + worker_execute_model(decode_output) + worker_sample_tokens(None) # Clean up - process finish_req_ids. cleanup_output = SchedulerOutput.make_empty() cleanup_output.finished_req_ids = set(req_ids) - model_runner.execute_model(cleanup_output) + worker_execute_model(cleanup_output) model_runner.kv_connector.set_disabled(False) torch.accelerator.synchronize() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 99efe6057..e56905fe7 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -584,7 +584,7 @@ class Worker(WorkerBase): if self.use_v2_model_runner: # V2: Run full execute_model + sample_tokens to JIT compile triton kernels. - warmup_kernels(self.model_runner) + warmup_kernels(self.model_runner, self.execute_model, self.sample_tokens) elif get_pp_group().is_last_rank: # V1: Warm up sampler and preallocate memory buffer for logits and other # sampling related tensors of max possible shape to avoid memory