diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 87b8bbf18..75655258c 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any import numpy as np import torch @@ -60,6 +59,8 @@ class InputBatch: query_start_loc_np: np.ndarray # [num_reqs] seq_lens: torch.Tensor + # [num_reqs] + dcp_local_seq_lens: torch.Tensor | None # [num_tokens_after_padding] input_ids: torch.Tensor @@ -68,11 +69,6 @@ class InputBatch: # [num_tokens_after_padding, hidden_size] inputs_embeds: torch.Tensor | None - # layer_name -> Metadata - attn_metadata: dict[str, Any] - # layer_name -> slot_mapping - slot_mappings: dict[str, torch.Tensor] - # [total_num_logits] logits_indices: torch.Tensor # [num_reqs + 1] @@ -139,11 +135,10 @@ class InputBatch: query_start_loc=query_start_loc, query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, + dcp_local_seq_lens=None, input_ids=input_ids, positions=positions, inputs_embeds=None, - attn_metadata=None, # type: ignore - slot_mappings=None, # type: ignore logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 949f09f54..7dcdaf1d2 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.attn_utils import ( - build_attn_metadata, build_slot_mappings_by_layer, get_kv_cache_spec, init_attn_backend, @@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) - def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None: - block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs) - slot_mappings = self.block_tables.get_dummy_slot_mappings( - input_batch.num_tokens - ) - slot_mappings_by_layer = build_slot_mappings_by_layer( - slot_mappings, self.kv_cache_config - ) - attn_metadata = build_attn_metadata( - attn_groups=self.attn_groups, - num_reqs=input_batch.num_reqs, - num_tokens=input_batch.num_tokens, - query_start_loc_gpu=input_batch.query_start_loc, - query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np), - max_query_len=input_batch.num_scheduled_tokens.max().item(), - seq_lens=input_batch.seq_lens, - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, - dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens, - ) - input_batch.attn_metadata = attn_metadata - input_batch.slot_mappings = slot_mappings_by_layer - @torch.inference_mode() def _dummy_run( self, num_tokens: int, *args, skip_attn: bool = True, **kwargs @@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): return None, None assert self.execute_model_state is not None - hidden_states, _, input_batch, _ = self.execute_model_state + input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state self.execute_model_state = None assert hidden_states is not None # Last PP rank always has hidden_states sample_hidden_states = hidden_states[input_batch.logits_indices] @@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.encoder_runner.add_request(req_id, new_req_data.mm_features) self.model_state.add_request(req_index, new_req_data) - self.block_tables.append_block_ids( req_index, new_req_data.block_ids, overwrite=True ) @@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): idx_mapping, total_num_logits, cu_num_logits, max_expand_len ) - # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] - block_tables = self.block_tables.gather_block_tables(idx_mapping) - # Get query_start_loc. query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) query_start_loc_np[0] = 0 @@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Some attention backends like FA3 require query_start_loc to be non-decreasing. query_start_loc_np[num_reqs + 1 :] = num_tokens async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) - query_start_loc_np = query_start_loc_np[: num_reqs + 1] - query_start_loc_cpu = torch.from_numpy(query_start_loc_np) query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - max_query_len = num_scheduled_tokens.max().item() # Get prefill tokens if any. if self.req_states.any_prefills(idx_mapping_np): @@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) seq_lens = self.input_buffers.seq_lens[:num_reqs] + dcp_local_seq_lens = None if self.use_dcp: # Prepare dcp local seq_lens. prepare_dcp_local_seq_lens( @@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.dcp_rank, self.cp_interleave, ) - dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] + dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] # Some input token ids are directly read from the last sampled tokens # and draft tokens. Also, get the logits indices to sample tokens from. @@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): total_num_logits, ) - # Compute slot mappings: [num_kv_cache_groups, num_tokens] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, - query_start_loc, - self.input_buffers.positions[:num_tokens], - ) - # Layer name -> slot mapping. - slot_mappings_by_layer = build_slot_mappings_by_layer( - slot_mappings, self.kv_cache_config - ) - - # Layer name -> attention metadata. - attn_metadata = build_attn_metadata( - attn_groups=self.attn_groups, - num_reqs=num_reqs, - num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=max_query_len, - seq_lens=self.input_buffers.seq_lens, - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, - dcp_local_seq_lens=dcp_local_seq_lens, - ) - - input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] - positions = self.input_buffers.positions[:num_tokens_after_padding] return InputBatch( req_ids=req_ids, num_reqs=num_reqs, @@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin): query_start_loc=query_start_loc, query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, - input_ids=input_ids, - positions=positions, + dcp_local_seq_lens=dcp_local_seq_lens, + input_ids=self.input_buffers.input_ids[:num_tokens_after_padding], + positions=self.input_buffers.positions[:num_tokens_after_padding], inputs_embeds=None, - attn_metadata=attn_metadata, - slot_mappings=slot_mappings_by_layer, logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, has_structured_output_reqs=scheduler_output.has_structured_output_requests, ) + def prepare_attn( + self, input_batch: InputBatch + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] + block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping) + # Compute slot mappings: [num_kv_cache_groups, num_tokens] + slot_mappings = self.block_tables.compute_slot_mappings( + input_batch.idx_mapping, + input_batch.query_start_loc, + input_batch.positions, + ) + return block_tables, slot_mappings + + def prepare_dummy_attn( + self, input_batch: InputBatch + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs) + slot_mappings = self.block_tables.get_dummy_slot_mappings( + input_batch.num_tokens + ) + return block_tables, slot_mappings + @torch.inference_mode() def get_mm_embeddings( self, @@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): input_batch = self.prepare_inputs( scheduler_output, num_tokens_after_padding ) + block_tables, slot_mappings = self.prepare_attn(input_batch) + if self.lora_config: # Activate LoRA adapters. lora_inputs = self.lora_state.make_lora_inputs( @@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): device=self.device, ) if not skip_attn_for_dummy_run: - self.prepare_dummy_attn_metadata(input_batch) + block_tables, slot_mappings = self.prepare_dummy_attn(input_batch) + else: + block_tables = None + slot_mappings = None # FIXME(woosuk): Fix warmup for LoRA. + attn_metadata = None + slot_mappings_by_layer = None + if not (dummy_run and skip_attn_for_dummy_run): + assert slot_mappings is not None + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) + assert block_tables is not None + attn_metadata = self.model_state.prepare_attn( + input_batch, + block_tables, + slot_mappings, + self.attn_groups, + self.kv_cache_config, + ) + model_inputs = { "input_ids": input_batch.input_ids, "positions": input_batch.positions, @@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) with set_forward_context( - input_batch.attn_metadata, + attn_metadata, self.vllm_config, num_tokens=input_batch.num_tokens_after_padding, cudagraph_runtime_mode=cudagraph_runtime_mode, num_tokens_across_dp=num_tokens_across_dp, batch_descriptor=batch_descriptor, - slot_mapping=input_batch.slot_mappings, + slot_mapping=slot_mappings_by_layer, ): self.kv_connector.pre_forward(scheduler_output) model_output = self.model(**model_inputs) @@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): aux_hidden_states = None kv_connector_output = self.kv_connector.post_forward(scheduler_output) + self.execute_model_state = ( + input_batch, + model_inputs, + attn_metadata, + slot_mappings_by_layer, + hidden_states, + aux_hidden_states, + kv_connector_output, + ) if not self.is_last_pp_rank: # Non-last PP rank: return IntermediateTensors for sending. assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output - self.execute_model_state = (None, None, input_batch, kv_connector_output) return hidden_states - # Last rank (or no PP): hidden_states is a tensor for sampling. assert isinstance(hidden_states, torch.Tensor) - self.execute_model_state = ( - hidden_states, - aux_hidden_states, - input_batch, - kv_connector_output, - ) return None @torch.inference_mode() @@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.execute_model_state is None: # The prior execute_model call must have failed. return None - hidden_states, aux_hidden_states, input_batch, kv_connector_output = ( - self.execute_model_state - ) + ( + input_batch, + model_inputs, + attn_metadata, + slot_mappings_by_layer, + hidden_states, + aux_hidden_states, + kv_connector_output, + ) = self.execute_model_state self.execute_model_state = None if not self.is_last_pp_rank: @@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.speculator is not None: draft_tokens = self.speculator.propose( input_batch, + attn_metadata, + slot_mappings_by_layer, hidden_states, aux_hidden_states, num_sampled, diff --git a/vllm/v1/worker/gpu/model_states.py b/vllm/v1/worker/gpu/model_states.py index 03574b2ad..838f177b3 100644 --- a/vllm/v1/worker/gpu/model_states.py +++ b/vllm/v1/worker/gpu/model_states.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.v1.core.sched.output import NewRequestData +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.states import RequestState +from vllm.v1.worker.utils import AttentionGroup class ModelState: @@ -72,3 +77,29 @@ class ModelState: return {} mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] return {"positions": mrope_positions} + + def prepare_attn( + self, + input_batch: InputBatch, + block_tables: tuple[torch.Tensor, ...], + slot_mappings: torch.Tensor, + attn_groups: list[list[AttentionGroup]], + kv_cache_config: KVCacheConfig, + ) -> dict[str, Any]: + query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) + max_query_len = input_batch.num_scheduled_tokens.max().item() + attn_metadata = build_attn_metadata( + attn_groups=attn_groups, + num_reqs=input_batch.num_reqs, + num_tokens=input_batch.num_tokens, + query_start_loc_gpu=input_batch.query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=max_query_len, + seq_lens=input_batch.seq_lens, + max_seq_len=self.max_model_len, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=kv_cache_config, + dcp_local_seq_lens=input_batch.dcp_local_seq_lens, + ) + return attn_metadata diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 6cd13cebf..0c85bf65e 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -182,6 +182,8 @@ class EagleSpeculator: def propose( self, input_batch: InputBatch, + attn_metadata: dict[str, Any], + slot_mappings: dict[str, torch.Tensor], # [num_tokens, hidden_size] last_hidden_states: torch.Tensor, # num_layers x [num_tokens, hidden_size] @@ -229,8 +231,8 @@ class EagleSpeculator: # TODO(woosuk): Support CUDA graph for prefill. last_hidden_states, hidden_states = self.run_model( num_tokens, - input_batch.attn_metadata, - input_batch.slot_mappings, + attn_metadata, + slot_mappings, num_tokens_across_dp=None, # FIXME ) sample_hidden_states = last_hidden_states[last_token_indices]