[Model Runner V2] Skip building deprecated fields in attn metadata (#32132)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
@@ -147,16 +146,13 @@ def build_attn_metadata(
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_np: np.ndarray,
|
||||
num_computed_tokens_cpu: torch.Tensor | None,
|
||||
max_seq_len: int,
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
max_query_len = int(query_start_loc_cpu.max())
|
||||
seq_lens = seq_lens[:num_reqs]
|
||||
seq_lens_cpu = torch.from_numpy(seq_lens_np)
|
||||
max_seq_len = int(seq_lens_np.max())
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
@@ -168,9 +164,7 @@ def build_attn_metadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
max_seq_len=max_seq_len,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
|
||||
@@ -232,11 +232,9 @@ def prepare_inputs_to_capture(
|
||||
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
|
||||
query_start_loc.np[num_reqs:] = num_tokens
|
||||
query_start_loc.copy_to_gpu()
|
||||
seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
|
||||
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len. This introduces a discrepancy between
|
||||
# seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
|
||||
# certain attention backends.
|
||||
# rather than max_model_len.
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
@@ -250,8 +248,7 @@ def prepare_inputs_to_capture(
|
||||
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
|
||||
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
num_computed_tokens_cpu=None, # FIXME
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
|
||||
@@ -70,7 +70,6 @@ class InputBatch:
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_np: np.ndarray
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
@@ -109,8 +108,6 @@ class InputBatch:
|
||||
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
|
||||
# seq_len equals to query_len
|
||||
seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
seq_lens_np[-1] += num_tokens % num_reqs
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
|
||||
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
@@ -133,7 +130,6 @@ class InputBatch:
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=None, # type: ignore
|
||||
|
||||
@@ -228,9 +228,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
num_computed_tokens = torch.zeros(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
|
||||
@@ -241,8 +238,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc_gpu=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
seq_lens_np=input_batch.seq_lens_np,
|
||||
num_computed_tokens_cpu=num_computed_tokens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
@@ -522,16 +518,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
|
||||
)
|
||||
|
||||
# Get num_computed_tokens.
|
||||
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
|
||||
# num_computed_tokens_cpu. This works for most cases.
|
||||
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping]
|
||||
# HACK(woosuk): Only GPU has the exact seq_lens because at this point
|
||||
# CPU does not know how many draft tokens are accepted/rejected in the
|
||||
# previous step. Therefore, we use max_model_len to be safe.
|
||||
# NOTE(woosuk): This only works for FA3 backend.
|
||||
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
@@ -540,8 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc_gpu=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
num_computed_tokens_cpu=num_computed_tokens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
@@ -561,7 +546,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
|
||||
@@ -288,8 +288,6 @@ class EagleSpeculator:
|
||||
# Run eager mode.
|
||||
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
|
||||
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
|
||||
# HACK(woosuk)
|
||||
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
@@ -300,8 +298,7 @@ class EagleSpeculator:
|
||||
query_start_loc_gpu=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
seq_lens_np=seq_lens_np,
|
||||
num_computed_tokens_cpu=None, # FIXME
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
|
||||
Reference in New Issue
Block a user