Signed-off-by: Mark McLoughlin <markmc@redhat.com>
@@ -4,8 +4,6 @@ import numpy as np
import torch

from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig

logger = init_logger(__name__)

@@ -98,48 +96,3 @@ class BlockTable:
    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np

class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, kv_cache_config: KVCacheConfig) -> None:
        max_num_blocks_per_req = [
            cdiv(max_model_len, g.kv_cache_spec.block_size)
            for g in kv_cache_config.kv_cache_groups
        ]
        self.block_tables = [
            BlockTable(max_num_reqs, max_num_blocks_per_req[i],
                       max_num_batched_tokens, pin_memory, device)
            for i in range(len(kv_cache_config.kv_cache_groups))
        ]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit(num_reqs)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
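
A minimal sketch of the two behaviours the class above relies on, with made-up
group block sizes and block IDs (the cdiv helper is re-declared locally for
illustration, not imported from vLLM): one BlockTable is sized per KV cache
group via ceiling division, and per-group block IDs fan out to the matching
table index.

# Illustration only: mirrors MultiGroupBlockTable's per-group sizing/fan-out.
def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the semantics of vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 8192
group_block_sizes = [16, 64]      # hypothetical: one KV cache group per size
max_blocks_per_req = [cdiv(max_model_len, bs) for bs in group_block_sizes]
print(max_blocks_per_req)         # [512, 128]

new_block_ids = [[0, 1, 2], [7]]  # hypothetical per-group block IDs for row 0
for i, ids in enumerate(new_block_ids):
    print(f"group {i}: append {ids} to row 0")
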
@@ -11,11 +11,10 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.block_table import BlockTable

_SAMPLING_EPS = 1e-5

@@ -30,7 +29,7 @@ class CachedRequestState:
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: list[list[int]]
    block_ids: list[int]
    num_computed_tokens: int
    output_token_ids: list[int]

@@ -59,14 +58,15 @@ class InputBatch:
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        kv_cache_config: KVCacheConfig,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
@@ -99,13 +99,12 @@ class InputBatch:
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_blocks_per_req=max_num_blocks_per_req,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            kv_cache_config=kv_cache_config,
        )

        # Sampling-related.

@@ -12,8 +12,6 @@ import torch.distributed
import torch.nn as nn

from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
@@ -34,8 +32,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                        is_pin_memory_available)
                        GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                        check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -53,7 +51,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

@@ -105,17 +102,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            cache_config.cache_dtype]

        # NOTE(woosuk): sliding_window is None for models with interleaved
        # attention. Use interleaved_sliding_window instead.
        self.sliding_window = model_config.get_sliding_window()
        self.interleaved_sliding_window = getattr(
            model_config.hf_text_config, "interleaved_sliding_window", None)
        self.window_size = (self.sliding_window
                            or self.interleaved_sliding_window)

        self.is_multimodal_model = model_config.is_multimodal_model
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        if self.vllm_config.compilation_config.full_cuda_graph:
            attn_backend_name = self.attn_backend.__name__
            flash_attn_version = get_flash_attn_version()
            if attn_backend_name != "FlashAttentionBackend" or \
                    flash_attn_version != 3:
                raise ValueError(
                    f"full_cuda_graph is only supported with "
                    f"FA3. Current attention backend is {attn_backend_name}, "
                    f"FlashAttention version is {flash_attn_version}.")

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
@@ -137,10 +176,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # self.model: nn.Module # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
        self.attn_backends: list[type[AttentionBackend]] = []
        # self.kv_cache_config: KVCacheConfig
        # self.input_batch: InputBatch # Persistent batch.
        # self.attn_metadata_builder: type[AttentionMetadataBuilder]

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -169,6 +206,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
@@ -263,31 +310,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.

        Returns:
            True if the batch was reordered, False otherwise.
        """
        batch_reordered = self.attn_metadata_builders[0].reorder_batch(
            self.input_batch, scheduler_output)

        # For models with multiple KV cache groups, the groups should agree on
        # the same order of requests. We ensure this by only allowing the first
        # group to reorder the batch and asserting that all other groups do not
        # reorder the batch.
        for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
            assert not self.attn_metadata_builders[i].reorder_batch(
                self.input_batch, scheduler_output)
        return batch_reordered
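
A minimal sketch, using stub objects rather than the real vLLM metadata
builders, of the invariant described in the comments above: only the first
group's builder may reorder the batch, and every other group must report that
it did not, so the per-group block tables stay aligned row-for-row.

class StubBuilder:
    """Hypothetical stand-in for an attention metadata builder."""

    def __init__(self, reorders: bool):
        self.reorders = reorders

    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        # A real builder may permute the batch; the stub only reports whether
        # it would have done so.
        return self.reorders

builders = [StubBuilder(True), StubBuilder(False), StubBuilder(False)]
batch_reordered = builders[0].reorder_batch(None, None)
for b in builders[1:]:
    assert not b.reorder_batch(None, None)
print(batch_reordered)  # True
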
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.
@@ -424,8 +446,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                for i in range(len(self.kv_cache_config.kv_cache_groups)):
                    req_state.block_ids[i].extend(req_data.new_block_ids[i])
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
@@ -483,7 +504,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        batch_reordered = self._may_reorder_batch(scheduler_output)
        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        batch_reordered = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)

        if batch_changed or batch_reordered:
            self.input_batch.refresh_sampling_metadata()
@@ -551,29 +576,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.input_batch.block_table.
               slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
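
A small NumPy sketch of the slot-mapping arithmetic above, reusing the example
from the comments (three requests, block_size = 2); K and the block numbers in
the toy block table are made-up values for illustration only.

import numpy as np

block_size = 2
K = 4  # hypothetical max_num_blocks_per_req
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
positions_np = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])

block_table_indices = req_indices * K + positions_np // block_size
print(block_table_indices)  # [0 0 4 4 5 5 6 8 8 9] = [0, 0, K, K, K+1, ...]

# Toy block table: row r holds the physical block numbers of request r.
block_table = np.array([[10, 11, 0, 0],
                        [20, 21, 22, 0],
                        [30, 31, 0, 0]])
block_numbers = block_table.flatten()[block_table_indices]
slot_mapping = block_numbers * block_size + positions_np % block_size
print(slot_mapping)  # [20 21 40 41 42 43 44 60 61 62]
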
@@ -615,6 +632,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        attn_metadata: dict[str, FlashAttentionMetadata] = {}
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        # NOTE(Chen): there is exactly one KV cache group that contains all
        # attention layers in the model for now, so the current logic for
        # getting attn_metadata is not related to kv_cache_group information.
        # Will extend this part to support multiple KV cache groups later.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

@@ -623,19 +644,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    self.attn_metadata_builders[kv_cache_group_id],
                    scheduler_output.num_common_prefix_blocks,
                )

            attn_metadata_i = (
                self.attn_metadata_builders[kv_cache_group_id].build(
                    num_reqs=num_reqs,
                    num_actual_tokens=total_num_scheduled_tokens,
                    max_query_len=max_num_scheduled_tokens,
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata))
            attn_metadata_i = self.attn_metadata_builder.build(
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata)
            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

@@ -673,8 +690,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

@@ -693,7 +708,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0
@@ -742,19 +757,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
                             kv_cache_spec.block_size)
        use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                              (isinstance(kv_cache_spec, FullAttentionSpec)
                               and kv_cache_spec.sliding_window is not None))
        assert isinstance(kv_cache_spec, AttentionSpec)
        use_cascade = attn_metadata_builder.use_cascade_attention(
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
        use_cascade = self.attn_metadata_builder.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
            num_kv_heads=self.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=use_sliding_window,
            use_sliding_window=self.window_size is not None,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0
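
A one-line arithmetic illustration of the rounding above, with made-up numbers:
common_prefix_len is clamped down to a whole number of blocks.

block_size = 16
common_prefix_len = 150
common_prefix_len = common_prefix_len // block_size * block_size
print(common_prefix_len)  # 144, i.e. 9 full blocks of 16 tokens
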
@@ -1640,7 +1651,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                        dtype=np.int32)

        if skip_attn:
            attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
            attn_metadata = None
        else:
            query_start_loc = self.query_start_loc[:num_reqs + 1]
            seq_lens = self.seq_lens[:num_reqs]
@@ -1648,19 +1659,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc, seq_lens=seq_lens)

            attn_metadata = {}
            for kv_cache_group_id, kv_cache_group_spec in enumerate(
                    self.kv_cache_config.kv_cache_groups):
                attn_metadata_i = (
                    self.attn_metadata_builders[kv_cache_group_id].build(
                        num_reqs=num_tokens,
                        num_actual_tokens=num_tokens,
                        max_query_len=num_tokens,
                        common_prefix_len=0,
                        common_attn_metadata=common_attn_metadata,
                    ))
                for layer_name in kv_cache_group_spec.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i
            attn_metadata = self.attn_metadata_builder.build(
                num_reqs=num_tokens,
                num_actual_tokens=num_tokens,
                max_query_len=num_tokens,
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
@@ -1890,56 +1895,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
        """
        assert len(self.attn_backends) == 0 and len(
            self.attn_metadata_builders
        ) == 0, "Attention backends are already initialized"
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            if not isinstance(kv_cache_spec, AttentionSpec):
                raise NotImplementedError(
                    "Only AttentionSpec is supported for now.")
            attn_backend_i = get_attn_backend(
                kv_cache_spec.head_size,
                self.dtype,
                kv_cache_spec.dtype,
                kv_cache_spec.block_size,
                self.model_config.is_attention_free,
                use_mla=kv_cache_spec.use_mla,
            )
            if attn_backend_i is None:
                error_msg = (
                    f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
                    f"{self.dtype=}, {kv_cache_spec.dtype=}, "
                    f"{kv_cache_spec.block_size=}, "
                    f"{self.model_config.is_attention_free=}, "
                    f"{kv_cache_spec.use_mla=}")
                logger.error(error_msg)
                raise NotImplementedError(
                    "Non-Attention backend is not supported by V1 "
                    "GPUModelRunner.")

            if self.vllm_config.compilation_config.full_cuda_graph:
                attn_backend_name = attn_backend_i.__name__
                flash_attn_version = get_flash_attn_version()
                if attn_backend_name != "FlashAttentionBackend" or \
                        flash_attn_version != 3:
                    raise ValueError(
                        f"full_cuda_graph is only supported with "
                        f"FA3. Current attention backend is "
                        f"{attn_backend_name}, FlashAttention version is "
                        f"{flash_attn_version}.")

            block_table_i = self.input_batch.block_table[i]
            attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
                weakref.proxy(self), kv_cache_spec, block_table_i)
            self.attn_backends.append(attn_backend_i)
            self.attn_metadata_builders.append(attn_metadata_builder_i)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
@@ -1947,21 +1902,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")
        self.kv_cache_config = kv_cache_config
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            kv_cache_config=kv_cache_config,
        )
        self.initialize_attn_backend(kv_cache_config)

        kv_caches: dict[str, torch.Tensor] = {}

        for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
@@ -1976,7 +1925,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
                if isinstance(kv_cache_spec, AttentionSpec):
                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
@@ -1996,6 +1945,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)

        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self),
            kv_cache_config.kv_cache_groups[0].kv_cache_spec,
            self.input_batch.block_table)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each

@@ -171,10 +171,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
        # self.input_batch: InputBatch # Persistent batch.

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.vocab_size,
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
@@ -190,7 +199,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):

        self.block_table_cpu = torch.zeros(
            (self.max_num_reqs, self.max_num_blocks_per_req),
            dtype=torch.int32,
            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
@@ -515,12 +524,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.input_batch.block_table[0].
               out=self.input_batch.block_table.
               slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
@@ -545,15 +554,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.position_ids = self.positions_cpu[:
                                               padded_total_num_scheduled_tokens].to(
                                                   self.device)
        self.input_batch.block_table[0].slot_mapping_cpu[
        self.input_batch.block_table.slot_mapping_cpu[
            total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = (
            self.input_batch.block_table[0].
            self.input_batch.block_table.
            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
                self.device))
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
            self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
@@ -1254,18 +1263,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            kv_cache_config=kv_cache_config,
        )
        assert self.block_table_cpu.dtype == self.input_batch.block_table[
            0].get_cpu_tensor().dtype

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups: