[BugFix] Re-fix async multimodal cpu tensor race condition (#31373)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: njhill <nickhill123@gmail.com>
Author: Nick Hill
Date: 2025-12-28 03:05:08 -08:00
Committed by: GitHub
commit 094fcce250
parent 573dd0e6f0

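In the new version of the hunk below, the profiling region and the input-prep synchronization are combined into one parenthesized with statement, so the whole preprocessing body (multimodal encoder execution, input preparation, and attention-metadata construction) runs inside self.synchronize_input_prep(). The following is a minimal, self-contained sketch of the general pattern only, not vLLM's implementation: TinyRunner, cpu_inputs_ready, and prepare_inputs are hypothetical stand-ins that illustrate how gating host-to-device copies on a synchronization context avoids racing with an asynchronous producer of CPU tensors.

    import threading
    from contextlib import contextmanager

    import torch


    class TinyRunner:
        """Hypothetical stand-in for a model runner with async CPU input prep."""

        def __init__(self) -> None:
            # Set by a background thread once the CPU tensors for the current
            # step have been fully written.
            self.cpu_inputs_ready = threading.Event()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        @contextmanager
        def synchronize_input_prep(self):
            # Block until the async producer has finished filling the CPU tensors.
            self.cpu_inputs_ready.wait()
            yield

        def prepare_inputs(self, cpu_tensor: torch.Tensor) -> torch.Tensor:
            # Combining both context managers in one parenthesized `with` keeps
            # the host-to-device copy inside the synchronized (and profiled) region.
            with (
                torch.profiler.record_function("preprocess"),
                self.synchronize_input_prep(),
            ):
                return cpu_tensor.to(self.device, non_blocking=True)


    if __name__ == "__main__":
        runner = TinyRunner()

        def producer(t: torch.Tensor) -> None:
            t.fill_(1.0)  # simulate async multimodal preprocessing on the CPU
            runner.cpu_inputs_ready.set()

        x = torch.zeros(4)
        threading.Thread(target=producer, args=(x,)).start()
        print(runner.prepare_inputs(x))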
@@ -3058,131 +3058,129 @@ class GPUModelRunner(
         scheduler_output = deepcopy(scheduler_output)
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
-        with record_function_or_nullcontext("gpu_model_runner: preprocess"):
-            with self.synchronize_input_prep():
-                # Update persistent batch states.
-                self._update_states(scheduler_output)
+        with (
+            record_function_or_nullcontext("gpu_model_runner: preprocess"),
+            self.synchronize_input_prep(),
+        ):
+            # Update persistent batch states.
+            self._update_states(scheduler_output)
 
-                if has_ec_transfer() and get_ec_transfer().is_producer:
-                    with self.maybe_get_ec_connector_output(
-                        scheduler_output,
-                        encoder_cache=self.encoder_cache,
-                    ) as ec_connector_output:
-                        self._execute_mm_encoder(scheduler_output)
-                    return make_empty_encoder_model_runner_output(scheduler_output)
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+                ) as ec_connector_output:
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(scheduler_output)
 
-                if not num_scheduled_tokens:
-                    if (
-                        self.parallel_config.distributed_executor_backend
-                        == "external_launcher"
-                        and self.parallel_config.data_parallel_size > 1
-                    ):
-                        # this is a corner case when both external launcher
-                        # and DP are enabled, num_scheduled_tokens could be
-                        # 0, and has_unfinished_requests in the outer loop
-                        # returns True. before returning early here we call
-                        # dummy run to ensure coordinate_batch_across_dp
-                        # is called into to avoid out of sync issues.
-                        self._dummy_run(1)
-                    if not has_kv_transfer_group():
-                        # Return empty ModelRunnerOutput if no work to do.
-                        return EMPTY_MODEL_RUNNER_OUTPUT
-                    return self.kv_connector_no_forward(
-                        scheduler_output, self.vllm_config
-                    )
+            if not num_scheduled_tokens:
+                if (
+                    self.parallel_config.distributed_executor_backend
+                    == "external_launcher"
+                    and self.parallel_config.data_parallel_size > 1
+                ):
+                    # this is a corner case when both external launcher
+                    # and DP are enabled, num_scheduled_tokens could be
+                    # 0, and has_unfinished_requests in the outer loop
+                    # returns True. before returning early here we call
+                    # dummy run to ensure coordinate_batch_across_dp
+                    # is called into to avoid out of sync issues.
+                    self._dummy_run(1)
+                if not has_kv_transfer_group():
+                    # Return empty ModelRunnerOutput if no work to do.
+                    return EMPTY_MODEL_RUNNER_OUTPUT
+                return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
 
-                if self.cache_config.kv_sharing_fast_prefill:
-                    assert not self.num_prompt_logprobs, (
-                        "--kv-sharing-fast-prefill produces incorrect "
-                        "logprobs for prompt tokens, tokens, please disable "
-                        "it when the requests need prompt logprobs"
-                    )
+            if self.cache_config.kv_sharing_fast_prefill:
+                assert not self.num_prompt_logprobs, (
+                    "--kv-sharing-fast-prefill produces incorrect "
+                    "logprobs for prompt tokens, tokens, please disable "
+                    "it when the requests need prompt logprobs"
+                )
 
-                num_reqs = self.input_batch.num_reqs
-                req_ids = self.input_batch.req_ids
-                tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
-                num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
-                max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
-                num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
+            num_reqs = self.input_batch.num_reqs
+            req_ids = self.input_batch.req_ids
+            tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+            num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
+            max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+            num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
 
-                (
-                    logits_indices,
-                    spec_decode_metadata,
-                ) = self._prepare_inputs(
-                    scheduler_output,
-                    num_scheduled_tokens_np,
-                )
+            logits_indices, spec_decode_metadata = self._prepare_inputs(
+                scheduler_output,
+                num_scheduled_tokens_np,
+            )
 
-                cascade_attn_prefix_lens = None
-                # Disable cascade attention when using microbatching (DBO)
-                if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
-                    # Pre-compute cascade attention prefix lengths
-                    cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
-                        num_scheduled_tokens_np,
-                        self.input_batch.num_computed_tokens_cpu[:num_reqs],
-                        scheduler_output.num_common_prefix_blocks,
-                    )
+            cascade_attn_prefix_lens = None
+            # Disable cascade attention when using microbatching (DBO)
+            if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
+                # Pre-compute cascade attention prefix lengths
+                cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
+                    num_scheduled_tokens_np,
+                    self.input_batch.num_computed_tokens_cpu[:num_reqs],
+                    scheduler_output.num_common_prefix_blocks,
+                )
 
-                (
-                    cudagraph_mode,
-                    batch_desc,
-                    should_ubatch,
-                    num_tokens_across_dp,
-                    cudagraph_stats,
-                ) = self._determine_batch_execution_and_padding(
-                    num_tokens=num_tokens_unpadded,
-                    num_reqs=num_reqs,
-                    num_scheduled_tokens_np=num_scheduled_tokens_np,
-                    max_num_scheduled_tokens=max_num_scheduled_tokens,
-                    use_cascade_attn=cascade_attn_prefix_lens is not None,
-                    num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
-                )
-                logger.debug(
-                    "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
-                    "should_ubatch: %s, num_tokens_across_dp: %s",
-                    cudagraph_mode,
-                    batch_desc,
-                    should_ubatch,
-                    num_tokens_across_dp,
-                )
-                num_tokens_padded = batch_desc.num_tokens
-                num_reqs_padded = (
-                    batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
-                )
-                ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
-                    should_ubatch,
-                    num_scheduled_tokens_np,
-                    num_tokens_padded,
-                    num_reqs_padded,
-                    self.parallel_config.num_ubatches,
-                )
-                logger.debug(
-                    "ubatch_slices: %s, ubatch_slices_padded: %s",
-                    ubatch_slices,
-                    ubatch_slices_padded,
-                )
+            (
+                cudagraph_mode,
+                batch_desc,
+                should_ubatch,
+                num_tokens_across_dp,
+                cudagraph_stats,
+            ) = self._determine_batch_execution_and_padding(
+                num_tokens=num_tokens_unpadded,
+                num_reqs=num_reqs,
+                num_scheduled_tokens_np=num_scheduled_tokens_np,
+                max_num_scheduled_tokens=max_num_scheduled_tokens,
+                use_cascade_attn=cascade_attn_prefix_lens is not None,
+                num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
+            )
+            logger.debug(
+                "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
+                "should_ubatch: %s, num_tokens_across_dp: %s",
+                cudagraph_mode,
+                batch_desc,
+                should_ubatch,
+                num_tokens_across_dp,
+            )
+            num_tokens_padded = batch_desc.num_tokens
+            num_reqs_padded = (
+                batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+            )
+            ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+                should_ubatch,
+                num_scheduled_tokens_np,
+                num_tokens_padded,
+                num_reqs_padded,
+                self.parallel_config.num_ubatches,
+            )
+            logger.debug(
+                "ubatch_slices: %s, ubatch_slices_padded: %s",
+                ubatch_slices,
+                ubatch_slices_padded,
+            )
 
-                pad_attn = cudagraph_mode == CUDAGraphMode.FULL
-                use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
-                ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
-                attn_metadata, spec_decode_common_attn_metadata = (
-                    self._build_attention_metadata(
-                        num_tokens=num_tokens_unpadded,
-                        num_tokens_padded=num_tokens_padded if pad_attn else None,
-                        num_reqs=num_reqs,
-                        num_reqs_padded=num_reqs_padded if pad_attn else None,
-                        max_query_len=max_num_scheduled_tokens,
-                        ubatch_slices=ubatch_slices_attn,
-                        logits_indices=logits_indices,
-                        use_spec_decode=use_spec_decode,
-                        num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
-                        cascade_attn_prefix_lens=cascade_attn_prefix_lens,
-                    )
-                )
+            pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+            use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+            ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
+            (attn_metadata, spec_decode_common_attn_metadata) = (
+                self._build_attention_metadata(
+                    num_tokens=num_tokens_unpadded,
+                    num_tokens_padded=num_tokens_padded if pad_attn else None,
+                    num_reqs=num_reqs,
+                    num_reqs_padded=num_reqs_padded if pad_attn else None,
+                    max_query_len=max_num_scheduled_tokens,
+                    ubatch_slices=ubatch_slices_attn,
+                    logits_indices=logits_indices,
+                    use_spec_decode=use_spec_decode,
+                    num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+                    cascade_attn_prefix_lens=cascade_attn_prefix_lens,
+                )
+            )
 
+            (
+                input_ids,