diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 2f015339e..e36a90d6c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -840,6 +840,24 @@ class SamplingParams(
             f"extra_args={self.extra_args})"
         )
 
+    @staticmethod
+    def for_sampler_warmup() -> "SamplingParams":
+        """Set parameters to exercise all sampler logic."""
+        return SamplingParams(
+            temperature=0.9,
+            top_p=0.9,
+            top_k=50,
+            min_p=0.1,
+            frequency_penalty=0.5,
+            presence_penalty=0.5,
+            repetition_penalty=1.2,
+            min_tokens=2,
+            logit_bias={0: -1.0, 1: 0.5},
+            _bad_words_token_ids=[[0], [1, 2]],
+            logprobs=5,
+            prompt_logprobs=1,
+        )
+
 
 class BeamSearchParams(
     msgspec.Struct,
diff --git a/vllm/v1/worker/gpu/warmup.py b/vllm/v1/worker/gpu/warmup.py
new file mode 100644
index 000000000..ffe5b33f7
--- /dev/null
+++ b/vllm/v1/worker/gpu/warmup.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import torch
+
+from vllm import PoolingParams, SamplingParams
+from vllm.v1.core.sched.output import (
+    CachedRequestData,
+    GrammarOutput,
+    NewRequestData,
+    SchedulerOutput,
+)
+from vllm.v1.request import Request
+from vllm.v1.worker.gpu.model_runner import GPUModelRunner
+
+
+@torch.inference_mode()
+def warmup_kernels(model_runner: GPUModelRunner) -> None:
+    """Run two execute_model + sample_tokens iterations to JIT compile
+    triton kernels.
+
+    The first iteration simulates a prefill with requests of 2 prompt
+    tokens each. The second iteration simulates a decode step with all
+    requests generating 1 token each.
+    """
+    prompt_token_ids = [0, 1]
+    prompt_len = len(prompt_token_ids)
+    num_reqs = min(
+        model_runner.scheduler_config.max_num_seqs,
+        model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
+    )
+
+    num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
+    req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]
+
+    # SamplingParams exercising all sampling features.
+    if model_runner.is_pooling_model:
+        sampling_params = None
+        pooling_params = PoolingParams()
+    else:
+        sampling_params = SamplingParams.for_sampler_warmup()
+        pooling_params = None
+
+    # Step 1: Prefill all requests with 2 prompt tokens each.
+    new_reqs = [
+        NewRequestData.from_request(
+            Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
+            # Each request uses a distinct block per KV cache group.
+            block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
+            prefill_token_ids=prompt_token_ids,
+        )
+        for i in range(num_reqs)
+    ]
+
+    prefill_output = SchedulerOutput.make_empty()
+    prefill_output.scheduled_new_reqs = new_reqs
+    prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
+    prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
+    prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
+
+    # Disable KV connector for warmup run.
+    model_runner.kv_connector.set_disabled(True)
+    model_runner.execute_model(prefill_output)
+
+    if not model_runner.is_pooling_model:
+        # Warm up sampler and perform a decode step for non-pooling models.
+
+        grammar_output = None
+        if model_runner.is_last_pp_rank:
+            # Build a GrammarOutput to exercise the structured output bitmask
+            # kernel during the prefill step.
+            vocab_size = model_runner.model_config.get_vocab_size()
+            bitmask_width = (vocab_size + 31) // 32
+            grammar_bitmask = np.full(
+                (len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
+            )
+            grammar_output = GrammarOutput(
+                structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
+            )
+
+        model_runner.sample_tokens(grammar_output)
+
+        # Step 2: Decode all requests with 1 token each.
+        cached_req_data = CachedRequestData.make_empty()
+        cached_req_data.req_ids = list(req_ids)
+        cached_req_data.new_block_ids = [None] * num_reqs
+        cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
+        cached_req_data.num_output_tokens = [1] * num_reqs
+
+        decode_output = SchedulerOutput.make_empty()
+        decode_output.scheduled_cached_reqs = cached_req_data
+        decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
+        decode_output.total_num_scheduled_tokens = num_reqs
+        decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
+
+        model_runner.execute_model(decode_output)
+        model_runner.sample_tokens(None)
+
+    # Clean up: process finished_req_ids.
+    cleanup_output = SchedulerOutput.make_empty()
+    cleanup_output.finished_req_ids = set(req_ids)
+    model_runner.execute_model(cleanup_output)
+    model_runner.kv_connector.set_disabled(False)
+    torch.cuda.synchronize()
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 3aeb20839..fcc0fdf88 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -61,6 +61,7 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
 from vllm.v1.worker.worker_base import WorkerBase
 from vllm.v1.worker.workspace import init_workspace_manager
 
+from .gpu.warmup import warmup_kernels
 from .utils import request_memory
 
 logger = init_logger(__name__)
@@ -558,12 +559,15 @@ class Worker(WorkerBase):
 
             logger.debug(msg)
 
-        # Warm up sampler and preallocate memory buffer for logits and other
-        # sampling related tensors of max possible shape to avoid memory
-        # fragmentation issue.
-        # NOTE: This is called after `capture_model` on purpose to prevent
-        # memory buffers from being cleared by `torch.cuda.empty_cache`.
-        if get_pp_group().is_last_rank:
+        if self.use_v2_model_runner:
+            # V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
+            warmup_kernels(self.model_runner)
+        elif get_pp_group().is_last_rank:
+            # V1: Warm up sampler and preallocate memory buffer for logits and other
+            # sampling related tensors of max possible shape to avoid memory
+            # fragmentation issue.
+            # NOTE: This is called after `capture_model` on purpose to prevent
+            # memory buffers from being cleared by `torch.cuda.empty_cache`.
             max_num_reqs = min(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
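
A brief note on the structured-output warmup in `warmup.py` above: the grammar bitmask packs one bit per vocabulary token into int32 words, hence the `(vocab_size + 31) // 32` ceiling division for its width, and filling it with `-1` sets every bit, which (under the usual convention that a set bit marks a token as allowed) exercises the bitmask kernel without actually constraining sampling. Below is a minimal standalone sketch of that construction; the `vocab_size` and request count are illustrative values rather than ones read from the model runner.

```python
import numpy as np

# Illustrative values; in warmup_kernels these come from
# model_runner.model_config.get_vocab_size() and len(req_ids).
vocab_size = 32000
num_reqs = 4

# One bit per vocabulary token, packed into 32-bit words (ceiling division).
bitmask_width = (vocab_size + 31) // 32  # 1000 words for a 32000-token vocab

# -1 as int32 is 0xFFFFFFFF, i.e. all bits set, so no token is masked out
# and the warmup decode remains unconstrained.
grammar_bitmask = np.full((num_reqs, bitmask_width), fill_value=-1, dtype=np.int32)

assert grammar_bitmask.shape == (num_reqs, bitmask_width)
```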