[Model Runner V2] support dp & ep for spec decoding (#35294)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai> Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com> Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import (
|
||||
get_cudagraph_and_dp_padding,
|
||||
make_num_tokens_across_dp,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
@@ -265,7 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
if self.speculator is not None:
|
||||
prepare_communication_buffer_for_model(self.speculator)
|
||||
prepare_communication_buffer_for_model(self.speculator.model)
|
||||
|
||||
# Initialize the components that require the model.
|
||||
self.model_state = init_model_state(
|
||||
@@ -382,8 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None, None
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
|
||||
# dummy run the eagle speculator's propose to ensure DP/EP sync.
|
||||
if self.speculator is not None:
|
||||
self.speculator.propose(
|
||||
input_batch=input_batch,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings_by_layer,
|
||||
last_hidden_states=hidden_states,
|
||||
aux_hidden_states=aux_hidden_states,
|
||||
num_sampled=torch.ones(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
num_rejected=torch.zeros(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
last_sampled=self.req_states.last_sampled_tokens,
|
||||
next_prefill_tokens=self.req_states.next_prefill_tokens,
|
||||
temperature=self.sampler.sampling_states.temperature.gpu,
|
||||
seeds=self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
dummy_run=True,
|
||||
skip_attn_for_dummy_run=skip_attn,
|
||||
)
|
||||
|
||||
assert hidden_states is not None # Last PP rank always has hidden_states
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -431,17 +461,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
self._dummy_pooler_run(hidden_states)
|
||||
|
||||
if self.speculator is not None:
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.parallel_config.data_parallel_size, self.max_num_tokens
|
||||
)
|
||||
self.speculator.run_model(
|
||||
self.max_num_tokens,
|
||||
attn_metadata=None,
|
||||
slot_mappings=None,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
@@ -979,6 +998,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
)
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
@@ -1005,6 +1025,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
|
||||
@@ -1078,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.next_prefill_tokens,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
|
||||
|
||||
@@ -55,6 +55,26 @@ class EagleCudaGraphManager:
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def get_cudagraph_runtime_mode(
|
||||
self, num_tokens: int
|
||||
) -> tuple[CUDAGraphMode, int | None]:
|
||||
cudagraph_size = self.get_cudagraph_size(num_tokens)
|
||||
if cudagraph_size is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode
|
||||
|
||||
if (
|
||||
cudagraph_mode == CUDAGraphMode.FULL
|
||||
and cudagraph_size is not None
|
||||
and cudagraph_size not in self.graphs
|
||||
):
|
||||
# If graph wasn't captured yet, fall back to eager.
|
||||
# This might happen when the dummy run is called before capture.
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
@@ -48,6 +49,10 @@ class EagleSpeculator:
|
||||
self.vocab_size = self.draft_model_config.get_vocab_size()
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
# DP configuration
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
@@ -122,8 +127,8 @@ class EagleSpeculator:
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens_padded: int,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
) -> None:
|
||||
@@ -164,9 +169,10 @@ class EagleSpeculator:
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
)
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
if attn_metadata is not None:
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
if self.num_speculative_steps == 1:
|
||||
@@ -203,6 +209,9 @@ class EagleSpeculator:
|
||||
temperature: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
seeds: torch.Tensor,
|
||||
num_tokens_across_dp: torch.Tensor | None = None,
|
||||
dummy_run: bool = False,
|
||||
skip_attn_for_dummy_run: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
|
||||
# number of rejected tokens, we maintain the size of eagle's input_ids and
|
||||
@@ -236,7 +245,7 @@ class EagleSpeculator:
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
@@ -282,48 +291,64 @@ class EagleSpeculator:
|
||||
self.max_model_len,
|
||||
self.max_num_reqs,
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
|
||||
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
cudagraph_mode, cudagraph_size = (
|
||||
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
|
||||
)
|
||||
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
get_cudagraph_and_dp_padding(
|
||||
num_reqs,
|
||||
cudagraph_size,
|
||||
cudagraph_mode.value,
|
||||
self.dp_size,
|
||||
self.dp_rank,
|
||||
)
|
||||
)
|
||||
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
|
||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# Run full CUDA graph.
|
||||
self.cudagraph_manager.run_fullgraph(cudagraph_size)
|
||||
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
# Run eager or piecewise CUDA graph.
|
||||
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
attn_metadata_updated = None
|
||||
slot_mappings_updated = None
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata_updated = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
slot_mappings_updated = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
self.generate_draft(
|
||||
num_reqs,
|
||||
num_tokens_padded,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
attn_metadata_updated,
|
||||
slot_mappings_updated,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
Reference in New Issue
Block a user