[Model Runner V2] Use ModelState.prepare_attn() for cuda graph capture [5/N] (#35774)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
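In short, this change stops prepare_inputs_to_capture() from calling build_attn_metadata() directly (with hand-assembled query_start_loc, seq_lens, block tables, and slot mappings) and instead has it build a dummy InputBatch and delegate attention-metadata construction to ModelState.prepare_attn(); the model_state is threaded through CudaGraphManager, EagleCudaGraphManager, and EagleSpeculator.set_attn() so the capture paths can reach it. The following is only a rough sketch of the resulting call shape, using illustrative stand-in names (ModelStateLike, prepare_inputs_to_capture_sketch) rather than the real vLLM interfaces:

# Sketch only: simplified stand-ins, not the actual vLLM classes or signatures.
from typing import Any, Protocol


class ModelStateLike(Protocol):
    # Stand-in for vllm.v1.worker.gpu.model_states.interface.ModelState.
    def prepare_attn(
        self, input_batch, block_tables, slot_mappings, attn_groups, kv_cache_config
    ) -> dict[str, Any]: ...


def prepare_inputs_to_capture_sketch(
    model_state: ModelStateLike,
    input_batch,      # dummy InputBatch sized for the graph being captured
    block_tables,     # dummy block tables
    slot_mappings,    # dummy slot mappings
    attn_groups,
    kv_cache_config,
) -> dict[str, Any]:
    # Before: attn_metadata = build_attn_metadata(attn_groups=..., num_reqs=...,
    #         query_start_loc_gpu=..., seq_lens=..., max_seq_len=..., ...)
    # After: the model state builds its own attention metadata.
    return model_state.prepare_attn(
        input_batch, block_tables, slot_mappings, attn_groups, kv_cache_config
    )

Presumably the point of keeping metadata construction behind the ModelState interface is that each model state (including the DCP-specific handling visible below) can prepare its own attention metadata without the graph-capture code duplicating that logic.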
@@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
@@ -15,13 +14,11 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
@@ -123,14 +120,11 @@ class CudaGraphManager:
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=(
                self.uniform_decode_query_len if uniform_decode else 0
            ),
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -393,51 +387,36 @@ def capture_graphs(
def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
    input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
    slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=num_tokens_per_req,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    # HACK(woosuk): Special handling for DCP.
    if block_tables.cp_size > 1:
        prepare_dcp_local_seq_lens(
            input_buffers.dcp_local_seq_lens,
            input_batch.seq_lens,
            num_reqs,
            block_tables.cp_size,
            block_tables.cp_rank,
            block_tables.cp_interleave,
        )
        input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]

    attn_metadata = model_state.prepare_attn(
        input_batch,
        input_block_tables,
        slot_mappings,
        attn_groups,
        kv_cache_config,
    )
    return attn_metadata, slot_mappings_by_layer
@@ -82,14 +82,16 @@ class InputBatch:
        num_reqs: int,
        num_tokens: int,
        input_buffers: InputBuffers,
        device: torch.device,
    ) -> "InputBatch":
        assert 0 < num_reqs <= num_tokens
        device = input_buffers.device

        req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
        expanded_idx_mapping = idx_mapping
        expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)

        num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        assert int(num_scheduled_tokens.sum()) == num_tokens
@@ -115,7 +117,6 @@ class InputBatch:
        input_ids = input_buffers.input_ids[:num_tokens].zero_()
        positions = input_buffers.positions[:num_tokens].zero_()

        # attn_metadata = defaultdict(lambda: None)
        logits_indices = query_start_loc[1:] - 1
        cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
        cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
@@ -311,6 +311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if self.speculator is not None:
            # HACK(woosuk)
            self.speculator.set_attn(
                self.model_state,
                self.kv_cache_config,
                self.attn_groups,
                self.block_tables,
@@ -880,10 +881,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # No actual tokens to run. A dummy run for DP or memory profiling.
        num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
        input_batch = InputBatch.make_dummy(
            num_reqs=num_reqs,
            num_tokens=num_tokens_after_padding,
            input_buffers=self.input_buffers,
            device=self.device,
            num_reqs, num_tokens_after_padding, self.input_buffers
        )
        if not skip_attn_for_dummy_run:
            block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup


@@ -59,6 +60,7 @@ class EagleCudaGraphManager:
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        generate_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
@@ -76,12 +78,11 @@ class EagleCudaGraphManager:
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=1,
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -158,6 +159,7 @@ class EagleCudaGraphManager:
    def capture(
        self,
        generate_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
@@ -173,6 +175,7 @@ class EagleCudaGraphManager:
            capture_cudagraph_mode=self.cudagraph_mode,
            desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
            generate_fn=generate_fn,
            model_state=model_state,
            input_buffers=input_buffers,
            block_tables=block_tables,
            attn_groups=attn_groups,
@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.attn_utils import (
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
@@ -76,10 +77,12 @@ class EagleSpeculator:

    def set_attn(
        self,
        model_state: ModelState,
        kv_cache_config: KVCacheConfig,
        attn_groups: list[list[AttentionGroup]],
        block_tables: BlockTables,
    ) -> None:
        self.model_state = model_state
        self.kv_cache_config = kv_cache_config
        self.attn_groups = attn_groups
        self.block_tables = block_tables
@@ -171,6 +174,7 @@ class EagleSpeculator:
        logger.info("Capturing model for Eagle speculator...")
        self.cudagraph_manager.capture(
            self.generate_draft,
            self.model_state,
            self.input_buffers,
            self.block_tables,
            self.attn_groups,