diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4117354dd..17dfcae59 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3058,131 +3058,129 @@ class GPUModelRunner(
         scheduler_output = deepcopy(scheduler_output)
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

-        with record_function_or_nullcontext("gpu_model_runner: preprocess"):
-            with self.synchronize_input_prep():
-                # Update persistent batch states.
-                self._update_states(scheduler_output)
+        with (
+            record_function_or_nullcontext("gpu_model_runner: preprocess"),
+            self.synchronize_input_prep(),
+        ):
+            # Update persistent batch states.
+            self._update_states(scheduler_output)

-                if has_ec_transfer() and get_ec_transfer().is_producer:
-                    with self.maybe_get_ec_connector_output(
-                        scheduler_output,
-                        encoder_cache=self.encoder_cache,
-                    ) as ec_connector_output:
-                        self._execute_mm_encoder(scheduler_output)
-                    return make_empty_encoder_model_runner_output(scheduler_output)
-
-                if not num_scheduled_tokens:
-                    if (
-                        self.parallel_config.distributed_executor_backend
-                        == "external_launcher"
-                        and self.parallel_config.data_parallel_size > 1
-                    ):
-                        # this is a corner case when both external launcher
-                        # and DP are enabled, num_scheduled_tokens could be
-                        # 0, and has_unfinished_requests in the outer loop
-                        # returns True. before returning early here we call
-                        # dummy run to ensure coordinate_batch_across_dp
-                        # is called into to avoid out of sync issues.
-                        self._dummy_run(1)
-                    if not has_kv_transfer_group():
-                        # Return empty ModelRunnerOutput if no work to do.
-                        return EMPTY_MODEL_RUNNER_OUTPUT
-                    return self.kv_connector_no_forward(
-                        scheduler_output, self.vllm_config
-                    )
-                if self.cache_config.kv_sharing_fast_prefill:
-                    assert not self.num_prompt_logprobs, (
-                        "--kv-sharing-fast-prefill produces incorrect "
-                        "logprobs for prompt tokens, tokens, please disable "
-                        "it when the requests need prompt logprobs"
-                    )
-
-                num_reqs = self.input_batch.num_reqs
-                req_ids = self.input_batch.req_ids
-                tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
-                num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
-                max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
-                num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
-
-                (
-                    logits_indices,
-                    spec_decode_metadata,
-                ) = self._prepare_inputs(
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
                     scheduler_output,
-                    num_scheduled_tokens_np,
+                    encoder_cache=self.encoder_cache,
+                ) as ec_connector_output:
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(scheduler_output)
+
+            if not num_scheduled_tokens:
+                if (
+                    self.parallel_config.distributed_executor_backend
+                    == "external_launcher"
+                    and self.parallel_config.data_parallel_size > 1
+                ):
+                    # this is a corner case when both external launcher
+                    # and DP are enabled, num_scheduled_tokens could be
+                    # 0, and has_unfinished_requests in the outer loop
+                    # returns True. before returning early here we call
+                    # dummy run to ensure coordinate_batch_across_dp
+                    # is called into to avoid out of sync issues.
+                    self._dummy_run(1)
+                if not has_kv_transfer_group():
+                    # Return empty ModelRunnerOutput if no work to do.
+                    return EMPTY_MODEL_RUNNER_OUTPUT
+                return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+            if self.cache_config.kv_sharing_fast_prefill:
+                assert not self.num_prompt_logprobs, (
+                    "--kv-sharing-fast-prefill produces incorrect "
+                    "logprobs for prompt tokens, please disable "
+                    "it when the requests need prompt logprobs"
                 )

-                cascade_attn_prefix_lens = None
-                # Disable cascade attention when using microbatching (DBO)
-                if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
-                    # Pre-compute cascade attention prefix lengths
-                    cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
-                        num_scheduled_tokens_np,
-                        self.input_batch.num_computed_tokens_cpu[:num_reqs],
-                        scheduler_output.num_common_prefix_blocks,
-                    )
+            num_reqs = self.input_batch.num_reqs
+            req_ids = self.input_batch.req_ids
+            tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+            num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
+            max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+            num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens

-                (
-                    cudagraph_mode,
-                    batch_desc,
-                    should_ubatch,
-                    num_tokens_across_dp,
-                    cudagraph_stats,
-                ) = self._determine_batch_execution_and_padding(
+            logits_indices, spec_decode_metadata = self._prepare_inputs(
+                scheduler_output,
+                num_scheduled_tokens_np,
+            )
+
+            cascade_attn_prefix_lens = None
+            # Disable cascade attention when using microbatching (DBO)
+            if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
+                # Pre-compute cascade attention prefix lengths
+                cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
+                    num_scheduled_tokens_np,
+                    self.input_batch.num_computed_tokens_cpu[:num_reqs],
+                    scheduler_output.num_common_prefix_blocks,
+                )
+
+            (
+                cudagraph_mode,
+                batch_desc,
+                should_ubatch,
+                num_tokens_across_dp,
+                cudagraph_stats,
+            ) = self._determine_batch_execution_and_padding(
+                num_tokens=num_tokens_unpadded,
+                num_reqs=num_reqs,
+                num_scheduled_tokens_np=num_scheduled_tokens_np,
+                max_num_scheduled_tokens=max_num_scheduled_tokens,
+                use_cascade_attn=cascade_attn_prefix_lens is not None,
+                num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
+            )
+
+            logger.debug(
+                "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
+                "should_ubatch: %s, num_tokens_across_dp: %s",
+                cudagraph_mode,
+                batch_desc,
+                should_ubatch,
+                num_tokens_across_dp,
+            )
+
+            num_tokens_padded = batch_desc.num_tokens
+            num_reqs_padded = (
+                batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+            )
+            ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+                should_ubatch,
+                num_scheduled_tokens_np,
+                num_tokens_padded,
+                num_reqs_padded,
+                self.parallel_config.num_ubatches,
+            )
+
+            logger.debug(
+                "ubatch_slices: %s, ubatch_slices_padded: %s",
+                ubatch_slices,
+                ubatch_slices_padded,
+            )
+
+            pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+
+            use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+            ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
+
+            attn_metadata, spec_decode_common_attn_metadata = (
+                self._build_attention_metadata(
                     num_tokens=num_tokens_unpadded,
+                    num_tokens_padded=num_tokens_padded if pad_attn else None,
                     num_reqs=num_reqs,
-                    num_scheduled_tokens_np=num_scheduled_tokens_np,
-                    max_num_scheduled_tokens=max_num_scheduled_tokens,
-                    use_cascade_attn=cascade_attn_prefix_lens is not None,
-                    num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
-                )
-
-                logger.debug(
"Running batch with cudagraph_mode: %s, batch_descriptor: %s, " - "should_ubatch: %s, num_tokens_across_dp: %s", - cudagraph_mode, - batch_desc, - should_ubatch, - num_tokens_across_dp, - ) - - num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - ) - ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, - num_scheduled_tokens_np, - num_tokens_padded, - num_reqs_padded, - self.parallel_config.num_ubatches, - ) - - logger.debug( - "ubatch_slices: %s, ubatch_slices_padded: %s", - ubatch_slices, - ubatch_slices_padded, - ) - - pad_attn = cudagraph_mode == CUDAGraphMode.FULL - - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices - - (attn_metadata, spec_decode_common_attn_metadata) = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=cascade_attn_prefix_lens, - ) + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, ) + ) ( input_ids,