[Model Runner V2] support dp & ep for spec decoding (#35294)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
zhrrr
2026-03-04 07:19:45 +08:00
committed by GitHub
parent 12b38c0f45
commit 97286a20ed
3 changed files with 123 additions and 56 deletions

View File

@@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import (
get_cudagraph_and_dp_padding,
make_num_tokens_across_dp,
)
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
@@ -265,7 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.model)
if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator)
prepare_communication_buffer_for_model(self.speculator.model)
# Initialize the components that require the model.
self.model_state = init_model_state(
@@ -382,8 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None
assert self.execute_model_state is not None
input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
num_rejected=torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
last_sampled=self.req_states.last_sampled_tokens,
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
)
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -431,17 +461,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
self._dummy_pooler_run(hidden_states)
if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
@@ -979,6 +998,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
)
if not self.is_last_pp_rank:
@@ -1005,6 +1025,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None
@@ -1078,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

View File

@@ -55,6 +55,26 @@ class EagleCudaGraphManager:
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
    self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
    """Resolve the CUDA graph mode and padded size to use for ``num_tokens``.

    Returns ``(CUDAGraphMode.NONE, None)`` when no size bucket exists for
    ``num_tokens``, or when FULL mode is configured but the graph for the
    bucket has not been captured yet (e.g. a dummy run before capture);
    otherwise returns the configured mode together with the bucket size.
    """
    size = self.get_cudagraph_size(num_tokens)
    mode = CUDAGraphMode.NONE if size is None else self.cudagraph_mode
    # A FULL-mode bucket whose graph is not captured yet cannot be
    # replayed, so fall back to eager execution. This can happen when a
    # dummy run is issued before graph capture.
    if (
        mode == CUDAGraphMode.FULL
        and size is not None
        and size not in self.graphs
    ):
        return CUDAGraphMode.NONE, None
    return mode, size
def capture_graph(
self,
num_tokens: int,

View File

@@ -16,6 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -48,6 +49,10 @@ class EagleSpeculator:
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
# DP configuration
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
@@ -122,8 +127,8 @@ class EagleSpeculator:
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
@@ -164,9 +169,10 @@ class EagleSpeculator:
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
if attn_metadata is not None:
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
@@ -203,6 +209,9 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -236,7 +245,7 @@ class EagleSpeculator:
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp=None, # FIXME
num_tokens_across_dp=num_tokens_across_dp,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -282,48 +291,64 @@ class EagleSpeculator:
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
num_reqs,
cudagraph_size,
cudagraph_mode.value,
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]