[Model Runner V2] Prepare attn metadata in ModelState [2/N] (#35383)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-26 11:47:02 -08:00
committed by GitHub
parent c66aa48e99
commit 3d66502e1b
4 changed files with 110 additions and 92 deletions

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
@@ -60,6 +59,8 @@ class InputBatch:
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_reqs]
dcp_local_seq_lens: torch.Tensor | None
# [num_tokens_after_padding]
input_ids: torch.Tensor
@@ -68,11 +69,6 @@ class InputBatch:
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits]
logits_indices: torch.Tensor
# [num_reqs + 1]
@@ -139,11 +135,10 @@ class InputBatch:
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
dcp_local_seq_lens=None,
input_ids=input_ids,
positions=positions,
inputs_embeds=None,
attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,

View File

@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
get_kv_cache_spec,
init_attn_backend,
@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
max_query_len=input_batch.num_scheduled_tokens.max().item(),
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
)
input_batch.attn_metadata = attn_metadata
input_batch.slot_mappings = slot_mappings_by_layer
@torch.inference_mode()
def _dummy_run(
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None
assert self.execute_model_state is not None
hidden_states, _, input_batch, _ = self.execute_model_state
input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
self.execute_model_state = None
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc.
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
dcp_local_seq_lens = None
if self.use_dcp:
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens(
@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank,
self.cp_interleave,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping,
query_start_loc,
self.input_buffers.positions[:num_tokens],
)
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=dcp_local_seq_lens,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
inputs_embeds=None,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
)
def prepare_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions,
)
return block_tables, slot_mappings
def prepare_dummy_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
return block_tables, slot_mappings
@torch.inference_mode()
def get_mm_embeddings(
self,
@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch = self.prepare_inputs(
scheduler_output, num_tokens_after_padding
)
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.lora_state.make_lora_inputs(
@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device,
)
if not skip_attn_for_dummy_run:
self.prepare_dummy_attn_metadata(input_batch)
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
else:
block_tables = None
slot_mappings = None
# FIXME(woosuk): Fix warmup for LoRA.
attn_metadata = None
slot_mappings_by_layer = None
if not (dummy_run and skip_attn_for_dummy_run):
assert slot_mappings is not None
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
assert block_tables is not None
attn_metadata = self.model_state.prepare_attn(
input_batch,
block_tables,
slot_mappings,
self.attn_groups,
self.kv_cache_config,
)
model_inputs = {
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
with set_forward_context(
input_batch.attn_metadata,
attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=input_batch.slot_mappings,
slot_mapping=slot_mappings_by_layer,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)
@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
)
if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, None, input_batch, kv_connector_output)
return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
self.execute_model_state = (
hidden_states,
aux_hidden_states,
input_batch,
kv_connector_output,
)
return None
@torch.inference_mode()
@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state
)
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
) = self.execute_model_state
self.execute_model_state = None
if not self.is_last_pp_rank:
@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
num_sampled,

View File

@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState:
@@ -72,3 +77,29 @@ class ModelState:
return {}
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
return {"positions": mrope_positions}
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
)
return attn_metadata

View File

@@ -182,6 +182,8 @@ class EagleSpeculator:
def propose(
self,
input_batch: InputBatch,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
@@ -229,8 +231,8 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
input_batch.slot_mappings,
attn_metadata,
slot_mappings,
num_tokens_across_dp=None, # FIXME
)
sample_hidden_states = last_hidden_states[last_token_indices]