diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index c9ae28abf..b4e7773cd 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -3,7 +3,6 @@ from collections.abc import Callable from typing import Any -import numpy as np import torch import torch.nn as nn from tqdm import tqdm @@ -15,13 +14,11 @@ from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.model_executor.offloader.base import get_offloader from vllm.utils.math_utils import cdiv from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.worker.gpu.attn_utils import ( - build_attn_metadata, - build_slot_mappings_by_layer, -) +from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp -from vllm.v1.worker.gpu.input_batch import InputBuffers +from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.utils import AttentionGroup @@ -123,14 +120,11 @@ class CudaGraphManager: attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, num_tokens, + model_state, input_buffers, block_tables, attn_groups, - self.max_model_len, kv_cache_config, - uniform_decode_query_len=( - self.uniform_decode_query_len if uniform_decode else 0 - ), ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) @@ -393,51 +387,36 @@ def capture_graphs( def prepare_inputs_to_capture( num_reqs: int, num_tokens: int, + model_state: ModelState, input_buffers: InputBuffers, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], - max_model_len: int, kv_cache_config: KVCacheConfig, - uniform_decode_query_len: int = 0, ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: - if uniform_decode_query_len > 0: - num_tokens_per_req = uniform_decode_query_len - else: - num_tokens_per_req = num_tokens // num_reqs - - query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req - query_start_loc_np[-1] = num_tokens - query_start_loc_cpu = torch.from_numpy(query_start_loc_np) - input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu - input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens - query_start_loc = input_buffers.query_start_loc[: num_reqs + 1] - - # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens - # rather than max_model_len. - input_buffers.seq_lens[:num_reqs] = num_tokens - input_buffers.seq_lens[num_reqs:] = 0 - - input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens - input_buffers.dcp_local_seq_lens[num_reqs:] = 0 - + input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers) input_block_tables = block_tables.get_dummy_block_tables(num_reqs) slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens) slot_mappings_by_layer = build_slot_mappings_by_layer( slot_mappings, kv_cache_config ) - attn_metadata = build_attn_metadata( - attn_groups=attn_groups, - num_reqs=num_reqs, - num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=num_tokens_per_req, - seq_lens=input_buffers.seq_lens, - max_seq_len=max_model_len, - block_tables=input_block_tables, - slot_mappings=slot_mappings, - kv_cache_config=kv_cache_config, - dcp_local_seq_lens=input_buffers.dcp_local_seq_lens, + # HACK(woosuk): Special handling for DCP. + if block_tables.cp_size > 1: + prepare_dcp_local_seq_lens( + input_buffers.dcp_local_seq_lens, + input_batch.seq_lens, + num_reqs, + block_tables.cp_size, + block_tables.cp_rank, + block_tables.cp_interleave, + ) + input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs] + + attn_metadata = model_state.prepare_attn( + input_batch, + input_block_tables, + slot_mappings, + attn_groups, + kv_cache_config, ) return attn_metadata, slot_mappings_by_layer diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 974f117d2..1ca87612e 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -82,14 +82,16 @@ class InputBatch: num_reqs: int, num_tokens: int, input_buffers: InputBuffers, - device: torch.device, ) -> "InputBatch": assert 0 < num_reqs <= num_tokens + device = input_buffers.device + req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) expanded_idx_mapping = idx_mapping expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device) + num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) num_scheduled_tokens[-1] += num_tokens % num_reqs assert int(num_scheduled_tokens.sum()) == num_tokens @@ -115,7 +117,6 @@ class InputBatch: input_ids = input_buffers.input_ids[:num_tokens].zero_() positions = input_buffers.positions[:num_tokens].zero_() - # attn_metadata = defaultdict(lambda: None) logits_indices = query_start_loc[1:] - 1 cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32) cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 63fa8fd65..35dd617ee 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -311,6 +311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.speculator is not None: # HACK(woosuk) self.speculator.set_attn( + self.model_state, self.kv_cache_config, self.attn_groups, self.block_tables, @@ -880,10 +881,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # No actual tokens to run. A dummy run for DP or memory profiling. num_reqs = min(num_tokens_after_padding, self.max_num_reqs) input_batch = InputBatch.make_dummy( - num_reqs=num_reqs, - num_tokens=num_tokens_after_padding, - input_buffers=self.input_buffers, - device=self.device, + num_reqs, num_tokens_after_padding, self.input_buffers ) if not skip_attn_for_dummy_run: block_tables, slot_mappings = self.prepare_dummy_attn(input_batch) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py index eda8c37d5..77dddf3ad 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py @@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import ( ) from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.input_batch import InputBuffers +from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.utils import AttentionGroup @@ -59,6 +60,7 @@ class EagleCudaGraphManager: num_tokens: int, capture_cg_mode: CUDAGraphMode, generate_fn: Callable, + model_state: ModelState, input_buffers: InputBuffers, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], @@ -76,12 +78,11 @@ class EagleCudaGraphManager: attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, num_tokens, + model_state, input_buffers, block_tables, attn_groups, - self.max_model_len, kv_cache_config, - uniform_decode_query_len=1, ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) @@ -158,6 +159,7 @@ class EagleCudaGraphManager: def capture( self, generate_fn: Callable, + model_state: ModelState, input_buffers: InputBuffers, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], @@ -173,6 +175,7 @@ class EagleCudaGraphManager: capture_cudagraph_mode=self.cudagraph_mode, desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})", generate_fn=generate_fn, + model_state=model_state, input_buffers=input_buffers, block_tables=block_tables, attn_groups=attn_groups, diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 74172ea18..9ea84386b 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers +from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model @@ -76,10 +77,12 @@ class EagleSpeculator: def set_attn( self, + model_state: ModelState, kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]], block_tables: BlockTables, ) -> None: + self.model_state = model_state self.kv_cache_config = kv_cache_config self.attn_groups = attn_groups self.block_tables = block_tables @@ -171,6 +174,7 @@ class EagleSpeculator: logger.info("Capturing model for Eagle speculator...") self.cudagraph_manager.capture( self.generate_draft, + self.model_state, self.input_buffers, self.block_tables, self.attn_groups,