break execute_model in gpu_model_runner into sub-functions for custom scopes (#24265)

Co-authored-by: Bangsheng Tang <bangsheng@meta.com>
Author: Bangsheng Tang
Date: 2025-09-06 14:02:47 -07:00
Committed by: GitHub
Commit: 848562bd49 (parent: e68dc2f014)
3 changed files with 208 additions and 109 deletions
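
The refactor wraps each stage of execute_model ("Preprocess", "Forward", "Postprocess", "Sample", "Bookkeep", "Draft", "EPLB") in a record_function_or_nullcontext helper imported from vllm.v1.utils. A minimal sketch of what such a helper might look like, assuming it simply wraps torch.profiler.record_function and degrades to a no-op context manager when scoping is disabled (the enabled flag below is an illustrative assumption, not the actual vLLM gating):

from contextlib import nullcontext
from torch.profiler import record_function

def record_function_or_nullcontext(name: str, enabled: bool = True):
    # Emit a named profiler range (visible in torch profiler / Kineto traces)
    # when enabled; otherwise return a context manager that does nothing.
    return record_function(name) if enabled else nullcontext()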


@@ -69,7 +69,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
DraftTokenIds, LogprobsLists, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
@@ -79,7 +80,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
@@ -1587,31 +1588,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_connector_output=kv_connector_output,
)
@torch.inference_mode()
def execute_model(
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, tokens, please disable it when the requests "
"need prompt logprobs")
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)
) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor,
Optional[IntermediateTensors], dict[str, Any]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1683,75 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
assert isinstance(hidden_states, IntermediateTensors)
if not broadcast_pp_output:
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
if self.is_pooling_model:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
return (
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
)
def _sample(
self, logits: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata]
) -> SamplerOutput:
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
@@ -1785,6 +1714,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
sampler_output.sampled_token_ids = output_token_ids
return sampler_output
def _bookkeeping_sync(
self, scheduler_output: "SchedulerOutput",
sampler_output: SamplerOutput, logits: Optional[torch.Tensor],
hidden_states: torch.Tensor, num_scheduled_tokens: int
) -> tuple[
dict[str, int],
Optional[LogprobsLists],
list[list[int]],
dict[str, Optional[LogprobsTensors]],
list[str],
dict[str, int],
list[int],
]:
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
@@ -1827,6 +1771,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
invalid_req_indices = []
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -1892,20 +1837,159 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
return (
num_nans_in_logits,
logprobs_lists,
valid_sampled_token_ids,
prompt_logprobs_dict,
req_ids_output_copy,
req_id_to_index_output_copy,
invalid_req_indices,
)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
with record_function_or_nullcontext("Preprocess"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, tokens, please disable it when the requests"
" need prompt logprobs")
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)
(
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
) = self._preprocess(scheduler_output, intermediate_tensors)
uniform_decode = (max_query_len
== self.uniform_decode_query_len) and (
num_scheduled_tokens
== self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model.
# Use persistent buffers for CUDA graphs.
with (set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), record_function_or_nullcontext("Forward"),
self.maybe_get_kv_connector_output(scheduler_output) as
kv_connector_output):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
self.eplb_step()
with record_function_or_nullcontext("Postprocess"):
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
assert isinstance(hidden_states, IntermediateTensors)
if not broadcast_pp_output:
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
get_pp_group().send_tensor_dict(
hidden_states.tensors, all_gather_group=get_tp_group())
logits = None
else:
if self.is_pooling_model:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np,
kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group(
).broadcast_tensor_dict(model_output_broadcast_data,
src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
with record_function_or_nullcontext("Bookkeep"):
assert isinstance(hidden_states, torch.Tensor)
(
num_nans_in_logits,
logprobs_lists,
valid_sampled_token_ids,
prompt_logprobs_dict,
req_ids_output_copy,
req_id_to_index_output_copy,
invalid_req_indices,
) = self._bookkeeping_sync(scheduler_output, sampler_output,
logits, hidden_states,
num_scheduled_tokens)
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
self.input_batch.sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
)
with record_function_or_nullcontext("EPLB"):
self.eplb_step()
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
@@ -1923,7 +2007,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=sampler_output.sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
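
With the stage scopes in place, a torch profiler trace of a single execute_model call can be broken down per stage. A hedged usage sketch (runner and scheduler_output are placeholders for a real GPUModelRunner and its scheduled batch; only the profiling wrapper is the point here):

from torch.profiler import ProfilerActivity, profile

# Hypothetical driver code: runner / scheduler_output are stand-ins.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    runner.execute_model(scheduler_output)

# The custom scopes ("Preprocess", "Forward", "Postprocess", "Sample",
# "Bookkeep", "Draft", "EPLB") appear as named ranges in the trace.
print(prof.key_averages().table(sort_by="cuda_time_total"))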