[Model Runner V2] Prepare attn metadata in ModelState [2/N] (#35383)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
@@ -60,6 +59,8 @@ class InputBatch:
    query_start_loc_np: np.ndarray
    # [num_reqs]
    seq_lens: torch.Tensor
    # [num_reqs]
    dcp_local_seq_lens: torch.Tensor | None

    # [num_tokens_after_padding]
    input_ids: torch.Tensor
@@ -68,11 +69,6 @@ class InputBatch:
    # [num_tokens_after_padding, hidden_size]
    inputs_embeds: torch.Tensor | None

    # layer_name -> Metadata
    attn_metadata: dict[str, Any]
    # layer_name -> slot_mapping
    slot_mappings: dict[str, torch.Tensor]

    # [total_num_logits]
    logits_indices: torch.Tensor
    # [num_reqs + 1]
@@ -139,11 +135,10 @@ class InputBatch:
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            dcp_local_seq_lens=None,
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=None,
            attn_metadata=None,  # type: ignore
            slot_mappings=None,  # type: ignore
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
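Net effect of the InputBatch changes above: the per-layer attention dictionaries move off the batch dataclass, and a DCP-aware sequence-length field moves on. A minimal before/after sketch, with field names taken from the diff, every unrelated field elided, and the class names invented for illustration:

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class InputBatchBefore:
    # [num_reqs]
    seq_lens: torch.Tensor
    # layer_name -> metadata (removed by this commit)
    attn_metadata: dict[str, Any]
    # layer_name -> slot_mapping (removed by this commit)
    slot_mappings: dict[str, torch.Tensor]


@dataclass
class InputBatchAfter:
    # [num_reqs]
    seq_lens: torch.Tensor
    # [num_reqs]; added, only populated when decode context parallelism is on
    dcp_local_seq_lens: torch.Tensor | None

With the dictionaries gone, make_dummy no longer has to stuff None placeholders into typed fields, which is why the two "# type: ignore" lines disappear from the constructor call above.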
@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
    get_kv_cache_spec,
    init_attn_backend,
@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
        )
        input_batch.attn_metadata = attn_metadata
        input_batch.slot_mappings = slot_mappings_by_layer

    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            return None, None

        assert self.execute_model_state is not None
        hidden_states, _, input_batch, _ = self.execute_model_state
        input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
        self.execute_model_state = None
        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            self.encoder_runner.add_request(req_id, new_req_data.mm_features)

            self.model_state.add_request(req_index, new_req_data)

            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
            )
@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            idx_mapping, total_num_logits, cu_num_logits, max_expand_len
        )

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

        # Get query_start_loc.
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
        query_start_loc_np[num_reqs + 1 :] = num_tokens
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
        max_query_len = num_scheduled_tokens.max().item()

        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

        dcp_local_seq_lens = None
        if self.use_dcp:
            # Prepare dcp local seq_lens.
            prepare_dcp_local_seq_lens(
@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                self.dcp_rank,
                self.cp_interleave,
            )
        dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            total_num_logits,
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            idx_mapping,
            query_start_loc,
            self.input_buffers.positions[:num_tokens],
        )
        # Layer name -> slot mapping.
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=self.input_buffers.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            dcp_local_seq_lens=dcp_local_seq_lens,
        )

        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
        positions = self.input_buffers.positions[:num_tokens_after_padding]
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            input_ids=input_ids,
            positions=positions,
            dcp_local_seq_lens=dcp_local_seq_lens,
            input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
            positions=self.input_buffers.positions[:num_tokens_after_padding],
            inputs_embeds=None,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings_by_layer,
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
        )

    def prepare_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            input_batch.positions,
        )
        return block_tables, slot_mappings

    def prepare_dummy_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        return block_tables, slot_mappings

    @torch.inference_mode()
    def get_mm_embeddings(
        self,
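The two helpers above split attention preparation out of prepare_inputs: the runner now only gathers raw block tables and computes slot mappings, and ModelState turns them into per-layer metadata. A sketch of the resulting call order, assuming a GPUModelRunner instance named `runner` and abbreviated argument plumbing; this is not the verbatim runner code:

from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer


def prepare_attention(runner, scheduler_output, num_tokens_after_padding):
    # 1. Build the batch; it no longer carries any attention metadata.
    input_batch = runner.prepare_inputs(scheduler_output, num_tokens_after_padding)
    # 2. Gather block tables and compute slot mappings from the batch.
    block_tables, slot_mappings = runner.prepare_attn(input_batch)
    # 3. Expand group-level slot mappings into a layer_name -> tensor dict.
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, runner.kv_cache_config
    )
    # 4. Let ModelState assemble the layer_name -> attention metadata dict.
    attn_metadata = runner.model_state.prepare_attn(
        input_batch,
        block_tables,
        slot_mappings,
        runner.attn_groups,
        runner.kv_cache_config,
    )
    return attn_metadata, slot_mappings_by_layer

Dummy runs follow the same flow with prepare_dummy_attn in step 2, or skip steps 2-4 entirely when skip_attn_for_dummy_run is set, as the next hunks show.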
@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            input_batch = self.prepare_inputs(
                scheduler_output, num_tokens_after_padding
            )
            block_tables, slot_mappings = self.prepare_attn(input_batch)

            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.lora_state.make_lora_inputs(
@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                device=self.device,
            )
            if not skip_attn_for_dummy_run:
                self.prepare_dummy_attn_metadata(input_batch)
                block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
            else:
                block_tables = None
                slot_mappings = None
            # FIXME(woosuk): Fix warmup for LoRA.

        attn_metadata = None
        slot_mappings_by_layer = None
        if not (dummy_run and skip_attn_for_dummy_run):
            assert slot_mappings is not None
            slot_mappings_by_layer = build_slot_mappings_by_layer(
                slot_mappings, self.kv_cache_config
            )
            assert block_tables is not None
            attn_metadata = self.model_state.prepare_attn(
                input_batch,
                block_tables,
                slot_mappings,
                self.attn_groups,
                self.kv_cache_config,
            )

        model_inputs = {
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )

        with set_forward_context(
            input_batch.attn_metadata,
            attn_metadata,
            self.vllm_config,
            num_tokens=input_batch.num_tokens_after_padding,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=batch_descriptor,
            slot_mapping=input_batch.slot_mappings,
            slot_mapping=slot_mappings_by_layer,
        ):
            self.kv_connector.pre_forward(scheduler_output)
            model_output = self.model(**model_inputs)
@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            aux_hidden_states = None

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)
        self.execute_model_state = (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
        )

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            self.execute_model_state = (None, None, input_batch, kv_connector_output)
            return hidden_states

        # Last rank (or no PP): hidden_states is a tensor for sampling.
        assert isinstance(hidden_states, torch.Tensor)
        self.execute_model_state = (
            hidden_states,
            aux_hidden_states,
            input_batch,
            kv_connector_output,
        )
        return None

    @torch.inference_mode()
@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None
        hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
            self.execute_model_state
        )
        (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
        ) = self.execute_model_state
        self.execute_model_state = None

        if not self.is_last_pp_rank:
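execute_model now hands sample_tokens a seven-element state instead of the old four-element one, and builds it once before the PP-rank branch rather than separately in each branch. Positional unpacking of a tuple this wide is easy to get wrong, so a NamedTuple would be one way to make it self-documenting; the following is a suggestion sketch with field names taken from the diff, not code from the commit:

from typing import Any, NamedTuple

import torch


class ExecuteModelState(NamedTuple):
    input_batch: Any                  # InputBatch
    model_inputs: dict[str, Any]
    attn_metadata: dict[str, Any] | None
    slot_mappings_by_layer: dict[str, torch.Tensor] | None
    # torch.Tensor on the last PP rank, IntermediateTensors otherwise
    hidden_states: Any
    aux_hidden_states: list[torch.Tensor] | None
    kv_connector_output: Any

Fields would then be read by name (state.attn_metadata) instead of by position.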
@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if self.speculator is not None:
            draft_tokens = self.speculator.propose(
                input_batch,
                attn_metadata,
                slot_mappings_by_layer,
                hidden_states,
                aux_hidden_states,
                num_sampled,

@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class ModelState:
@@ -72,3 +77,29 @@ class ModelState:
            return {}
        mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
        return {"positions": mrope_positions}

    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
        )
        return attn_metadata
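ModelState.prepare_attn is now the single place that turns batch-level tensors into the layer_name -> metadata dict that set_forward_context consumes. A self-contained sketch of that fan-out pattern with stand-in types (the real work happens in build_attn_metadata from vllm.v1.worker.gpu.attn_utils; the function below and the layer names are illustrative, not the actual signature):

import torch


def build_metadata_by_layer(
    layer_names: list[str],
    query_start_loc: torch.Tensor,  # [num_reqs + 1]
    seq_lens: torch.Tensor,         # [num_reqs]
) -> dict[str, dict[str, torch.Tensor]]:
    # Layers that share a KV-cache group can share one metadata object,
    # so build it once and alias the same dict under every layer name.
    shared = {"query_start_loc": query_start_loc, "seq_lens": seq_lens}
    return {name: shared for name in layer_names}


layer_metadata = build_metadata_by_layer(
    ["model.layers.0.self_attn", "model.layers.1.self_attn"],
    query_start_loc=torch.tensor([0, 3, 7], dtype=torch.int32),
    seq_lens=torch.tensor([3, 4], dtype=torch.int32),
)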
@@ -182,6 +182,8 @@ class EagleSpeculator:
    def propose(
        self,
        input_batch: InputBatch,
        attn_metadata: dict[str, Any],
        slot_mappings: dict[str, torch.Tensor],
        # [num_tokens, hidden_size]
        last_hidden_states: torch.Tensor,
        # num_layers x [num_tokens, hidden_size]
@@ -229,8 +231,8 @@ class EagleSpeculator:
        # TODO(woosuk): Support CUDA graph for prefill.
        last_hidden_states, hidden_states = self.run_model(
            num_tokens,
            input_batch.attn_metadata,
            input_batch.slot_mappings,
            attn_metadata,
            slot_mappings,
            num_tokens_across_dp=None,  # FIXME
        )
        sample_hidden_states = last_hidden_states[last_token_indices]
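The speculator-side change mirrors the runner: EagleSpeculator.propose and run_model now receive attn_metadata and slot_mappings explicitly instead of reading input_batch.attn_metadata and input_batch.slot_mappings. Making the data flow explicit is what lets the runner build the metadata once in ModelState and hand the same dicts to the draft pass. A call-site sketch matching the runner hunk above; arguments after num_sampled are truncated in that hunk and stay elided here:

draft_tokens = self.speculator.propose(
    input_batch,
    attn_metadata,
    slot_mappings_by_layer,
    hidden_states,
    aux_hidden_states,
    num_sampled,
    # ...remaining arguments as in the original call (not shown in the hunk)
)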