[Model Runner V2] Fix warmup for pipeline parallel (#36280)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -17,9 +20,14 @@ from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
def warmup_kernels(
|
||||||
|
model_runner: GPUModelRunner,
|
||||||
|
worker_execute_model: Callable[[SchedulerOutput], Any],
|
||||||
|
worker_sample_tokens: Callable[[GrammarOutput | None], Any],
|
||||||
|
) -> None:
|
||||||
"""Run two execute_model + sample_tokens iterations to JIT compile
|
"""Run two execute_model + sample_tokens iterations to JIT compile
|
||||||
triton kernels.
|
triton kernels. We must call the provided worker's execute_model for
|
||||||
|
pipeline parallel coordination.
|
||||||
|
|
||||||
The first iteration simulates a prefill with requests of 2 prompt
|
The first iteration simulates a prefill with requests of 2 prompt
|
||||||
tokens each. The second iteration simulates a decode step with all
|
tokens each. The second iteration simulates a decode step with all
|
||||||
@@ -83,7 +91,7 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
|||||||
|
|
||||||
# Disable KV connector for warmup run.
|
# Disable KV connector for warmup run.
|
||||||
model_runner.kv_connector.set_disabled(True)
|
model_runner.kv_connector.set_disabled(True)
|
||||||
model_runner.execute_model(prefill_output)
|
worker_execute_model(prefill_output)
|
||||||
|
|
||||||
if not model_runner.is_pooling_model:
|
if not model_runner.is_pooling_model:
|
||||||
# Warm up sampler and perform a decode step for non-pooling models.
|
# Warm up sampler and perform a decode step for non-pooling models.
|
||||||
@@ -101,7 +109,7 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
|||||||
structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
|
structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
|
||||||
)
|
)
|
||||||
|
|
||||||
model_runner.sample_tokens(grammar_output)
|
worker_sample_tokens(grammar_output)
|
||||||
|
|
||||||
# Step 2: Decode all requests with 1 token each.
|
# Step 2: Decode all requests with 1 token each.
|
||||||
cached_req_data = CachedRequestData.make_empty()
|
cached_req_data = CachedRequestData.make_empty()
|
||||||
@@ -120,12 +128,12 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
|||||||
decode_output.total_num_scheduled_tokens = num_reqs
|
decode_output.total_num_scheduled_tokens = num_reqs
|
||||||
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
|
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
|
||||||
|
|
||||||
model_runner.execute_model(decode_output)
|
worker_execute_model(decode_output)
|
||||||
model_runner.sample_tokens(None)
|
worker_sample_tokens(None)
|
||||||
|
|
||||||
# Clean up - process finish_req_ids.
|
# Clean up - process finish_req_ids.
|
||||||
cleanup_output = SchedulerOutput.make_empty()
|
cleanup_output = SchedulerOutput.make_empty()
|
||||||
cleanup_output.finished_req_ids = set(req_ids)
|
cleanup_output.finished_req_ids = set(req_ids)
|
||||||
model_runner.execute_model(cleanup_output)
|
worker_execute_model(cleanup_output)
|
||||||
model_runner.kv_connector.set_disabled(False)
|
model_runner.kv_connector.set_disabled(False)
|
||||||
torch.accelerator.synchronize()
|
torch.accelerator.synchronize()
|
||||||
|
|||||||
@@ -584,7 +584,7 @@ class Worker(WorkerBase):
|
|||||||
|
|
||||||
if self.use_v2_model_runner:
|
if self.use_v2_model_runner:
|
||||||
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
|
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
|
||||||
warmup_kernels(self.model_runner)
|
warmup_kernels(self.model_runner, self.execute_model, self.sample_tokens)
|
||||||
elif get_pp_group().is_last_rank:
|
elif get_pp_group().is_last_rank:
|
||||||
# V1: Warm up sampler and preallocate memory buffer for logits and other
|
# V1: Warm up sampler and preallocate memory buffer for logits and other
|
||||||
# sampling related tensors of max possible shape to avoid memory
|
# sampling related tensors of max possible shape to avoid memory
|
||||||
|
|||||||
Reference in New Issue
Block a user