diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 3d626a2c5..f97e223df 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -101,9 +101,6 @@ class CudaGraphManager: kv_cache_config, ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) - slot_mappings_by_layer = build_slot_mappings_by_layer( - slot_mappings, kv_cache_config - ) # Warm up. with set_forward_context( @@ -112,7 +109,7 @@ class CudaGraphManager: num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, - slot_mapping=slot_mappings_by_layer, + slot_mapping=slot_mappings, ): hidden_states = model( input_ids=input_ids, @@ -132,7 +129,7 @@ class CudaGraphManager: num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, - slot_mapping=slot_mappings_by_layer, + slot_mapping=slot_mappings, ), torch.cuda.graph(graph, self.pool), ): @@ -252,7 +249,7 @@ def prepare_inputs_to_capture( attn_metadata_builders: list[AttentionMetadataBuilder], max_model_len: int, kv_cache_config: KVCacheConfig, -) -> tuple[dict[str, Any], torch.Tensor]: +) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: num_tokens_per_req = num_tokens // num_reqs query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req @@ -269,6 +266,9 @@ def prepare_inputs_to_capture( input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] slot_mappings = block_tables.slot_mappings[:, :num_tokens] + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, kv_cache_config + ) attn_metadata = build_attn_metadata( attn_metadata_builders=attn_metadata_builders, @@ -282,4 +282,4 @@ def prepare_inputs_to_capture( slot_mappings=slot_mappings, kv_cache_config=kv_cache_config, ) - return attn_metadata, slot_mappings + return attn_metadata, slot_mappings_by_layer diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index d6069c4cf..b3ab15178 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -66,6 +66,8 @@ class InputBatch: # layer_name -> Metadata attn_metadata: dict[str, Any] + # layer_name -> slot_mapping + slot_mappings: dict[str, torch.Tensor] # [total_num_logits] logits_indices: torch.Tensor @@ -133,6 +135,7 @@ class InputBatch: mrope_positions=None, inputs_embeds=None, attn_metadata=None, # type: ignore + slot_mappings=None, # type: ignore logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 75d4c4e00..0206fb9b2 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -269,6 +269,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): slot_mappings = self.block_tables.get_dummy_slot_mappings( input_batch.num_tokens ) + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) attn_metadata = build_attn_metadata( attn_metadata_builders=self.attn_metadata_builders, num_reqs=input_batch.num_reqs, @@ -282,6 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_cache_config=self.kv_cache_config, ) input_batch.attn_metadata = attn_metadata + input_batch.slot_mappings = slot_mappings_by_layer @torch.inference_mode() def _dummy_run( @@ -345,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.speculator.run_model( self.max_num_tokens, attn_metadata=None, + slot_mappings=None, num_tokens_across_dp=num_tokens_across_dp, ) torch.cuda.synchronize() @@ -615,6 +620,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): query_start_loc, self.input_buffers.positions[:num_tokens], ) + # Layer name -> slot mapping. + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) # Layer name -> attention metadata. attn_metadata = build_attn_metadata( @@ -655,6 +664,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): mrope_positions=mrope_positions, inputs_embeds=None, attn_metadata=attn_metadata, + slot_mappings=slot_mappings_by_layer, logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, @@ -882,14 +892,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: assert input_batch.mrope_positions is not None positions = input_batch.mrope_positions - slot_mappings = self.block_tables.compute_slot_mappings( - input_batch.idx_mapping, - input_batch.query_start_loc, - input_batch.positions[: input_batch.num_tokens], - ) - slot_mappings_by_layer = build_slot_mappings_by_layer( - slot_mappings, self.kv_cache_config - ) with set_forward_context( input_batch.attn_metadata, self.vllm_config, @@ -897,7 +899,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): Support piecewise CUDA graph. cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, - slot_mapping=slot_mappings_by_layer, + slot_mapping=input_batch.slot_mappings, ): self.kv_connector.pre_forward(scheduler_output) hidden_states = self.model( diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index f86b53793..79343e54d 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -13,7 +13,10 @@ from vllm.model_executor.model_loader import get_model from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.worker.gpu.attn_utils import build_attn_metadata +from vllm.v1.worker.gpu.attn_utils import ( + build_attn_metadata, + build_slot_mappings_by_layer, +) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample @@ -108,7 +111,8 @@ class EagleSpeculator: def run_model( self, num_tokens: int, - attn_metadata: dict[str, Any], + attn_metadata: dict[str, Any] | None, + slot_mappings: dict[str, torch.Tensor] | None, num_tokens_across_dp: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: with set_forward_context( @@ -117,6 +121,7 @@ class EagleSpeculator: num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, + slot_mapping=slot_mappings, ): ret_hidden_states = self.model( input_ids=self.input_buffers.input_ids[:num_tokens], @@ -134,6 +139,7 @@ class EagleSpeculator: self, num_reqs: int, attn_metadata: dict[str, Any], + slot_mappings: dict[str, torch.Tensor], num_tokens_across_dp: torch.Tensor | None, ) -> None: pos = self.input_buffers.positions[:num_reqs] @@ -142,7 +148,7 @@ class EagleSpeculator: for step in range(1, self.num_speculative_steps): # Run the eagle model. last_hidden_states, hidden_states = self.run_model( - num_reqs, attn_metadata, num_tokens_across_dp + num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp ) logits = self.model.compute_logits(last_hidden_states) @@ -235,6 +241,7 @@ class EagleSpeculator: last_hidden_states, hidden_states = self.run_model( num_tokens, input_batch.attn_metadata, + input_batch.slot_mappings, num_tokens_across_dp=None, # FIXME ) sample_hidden_states = last_hidden_states[last_token_indices] @@ -311,7 +318,12 @@ class EagleSpeculator: slot_mappings=slot_mappings, kv_cache_config=self.kv_cache_config, ) - self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) + self.generate_draft( + num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None + ) # FIXME return self.draft_tokens[:num_reqs] diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py index c4a511778..33873418c 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py @@ -69,7 +69,7 @@ class EagleCudaGraphManager: kv_cache_config: KVCacheConfig, ) -> None: num_reqs = min(num_tokens, self.max_num_reqs) - attn_metadata = prepare_inputs_to_capture( + attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, num_tokens, input_buffers, @@ -81,13 +81,13 @@ class EagleCudaGraphManager: num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) # Warm up. - generate_fn(num_tokens, attn_metadata, num_tokens_across_dp) + generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp) # Capture the graph. assert num_tokens not in self.graphs graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, self.pool): - generate_fn(num_tokens, attn_metadata, num_tokens_across_dp) + generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp) self.graphs[num_tokens] = graph @torch.inference_mode()