diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index f197cbb7b..ac08b9052 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -17,7 +17,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, - _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1 + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN ] # Remove flashinfer from the list if it's not available diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index be6cfce6f..78a650998 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -109,11 +109,11 @@ def create_common_attn_metadata( def get_attention_backend(backend_name: _Backend): """Set up attention backend classes for testing. - + Args: backend_name: Name of the backend ("flash_attn", "flashinfer", etc.) vllm_config: VllmConfig instance - + Returns: Tuple of (backend_builder_class, backend_impl_class) """ @@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", _Backend.TRITON_ATTN_VLLM_V1: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", + _Backend.TREE_ATTN: + "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", } if backend_name not in backend_map: diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index a126c7c94..05f6dd40a 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) -def test_propose(num_speculative_tokens): +@pytest.mark.parametrize("backend", + [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN]) +def test_propose(num_speculative_tokens, backend): # Use GPU device device = torch.device(current_platform.device_type) @@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens): device=device) sampling_metadata = mock.MagicMock() - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.FLASH_ATTN_VLLM_V1) + attn_metadata_builder_cls, _ = get_attention_backend(backend) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py new file mode 100644 index 000000000..42468daa6 --- /dev/null +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch + +from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.config import ParallelConfig, SpeculativeConfig +from vllm.v1.attention.backends.utils import CommonAttentionMetadata + + +class MockAttentionLayer(torch.nn.Module): + _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") + _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") + _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") + + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +def forward_attention( + q: torch.Tensor, + k: torch.Tensor, + v: 
torch.Tensor, + kv_cache: torch.Tensor, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + seqlen_k: int, + backend: _Backend, + spec_token_tree: Optional[str] = None, + num_spec_tokens: int = 0, +) -> torch.Tensor: + batch_size, q_len, num_heads, dim_per_head = q.shape + num_kv_heads = k.shape[-2] + # Initialize the query and KV sequence lengths. + query_start_loc = q_len * torch.arange( + batch_size + 1, device=q.device, dtype=torch.int32) + query_lens = torch.diff(query_start_loc) + seq_lens = torch.full( + (batch_size, ), + seqlen_k, + device=q.device, + dtype=torch.int32, + ) + context_lens = seq_lens - query_lens + max_query_len = q_len + num_actual_tokens = query_start_loc[-1] + + softmax_scale = q.shape[-1]**(-0.5) + layer = MockAttentionLayer() + + # Build common metadata. + model_name = "meta-llama/Meta-Llama-3-8B" + builder_cls, impl_cls = get_attention_backend(backend) + vllm_config = create_vllm_config(model_name=model_name, + max_model_len=max(seq_lens)) + if spec_token_tree is not None: + # Create speculative config if token tree is specified. + vllm_config.speculative_config = SpeculativeConfig( + target_model_config=vllm_config.model_config, + target_parallel_config=ParallelConfig(), + model=model_name, + method="eagle", + num_speculative_tokens=num_spec_tokens, + speculative_token_tree=spec_token_tree) + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + builder = builder_cls(kv_cache_spec, [], vllm_config, q.device) + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc.cpu(), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + num_computed_tokens_cpu=context_lens.cpu(), + num_reqs=batch_size, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + block_table_tensor=block_table, + slot_mapping=slot_mapping, + ) + + # Build attention metadata. + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Initialize the backend implementation. + instance = impl_cls( + num_heads=num_heads, + head_size=dim_per_head, + scale=softmax_scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + # Run forward pass and return output. + query = q.view(-1, num_heads, dim_per_head) + key = k.view(-1, num_kv_heads, dim_per_head) + value = v.view(-1, num_kv_heads, dim_per_head) + output = torch.empty_like(query) + return instance.forward( + layer=layer, + query=query, + key=key, + value=value, + kv_cache=kv_cache.clone(), + attn_metadata=attn_metadata, + output=output, + ) + + +def test_tree_attn_correctness() -> None: + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + device = "cuda" + tree_attn_masks = { + # Chain. + "[(0,), (0, 0), (0, 0, 0)]": + torch.tensor( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + device=device, + dtype=torch.int32, + ), + # Tree. 
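+        # Root with two children, each of which has two children (7 nodes
+        # including the implicit root); every row attends to itself, the
+        # root, and its ancestors.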
+ "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": + torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0], + [1, 1, 0, 1, 0, 0, 0], + [1, 1, 0, 0, 1, 0, 0], + [1, 0, 1, 0, 0, 1, 0], + [1, 0, 1, 0, 0, 0, 1], + ], + device=device, + dtype=torch.int32, + ), + } + + dim_per_head = 128 + num_kv_heads = 2 + block_size = 128 + max_sequence_length = 8192 + randomize_blocks = True + for batch_size in [1, 16, 32]: + for num_heads in [2, 4]: + for sequence_position in [16, 1024, 2048]: + for spec_token_tree, tree_attn_mask in tree_attn_masks.items(): + # Assert that the number of heads is divisible + # by the number of KV heads. + assert num_heads % num_kv_heads == 0 + + # Initialize q, k, and v. + tree_size_q = tree_attn_mask.shape[0] + seqlen_k = sequence_position + tree_size_q + q = torch.randn( + (batch_size, tree_size_q, num_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + k = torch.randn( + (batch_size, tree_size_q, num_kv_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + v = torch.randn( + (batch_size, tree_size_q, num_kv_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + + # Setup the block table and KV cache for paged KV. + assert max_sequence_length % block_size == 0 + max_blocks_per_batch = max_sequence_length // block_size + kv_cache = torch.randn( + ( + 2, + batch_size * max_blocks_per_batch, + block_size, + num_kv_heads, + dim_per_head, + ), + device=q.device, + dtype=torch.bfloat16, + ) + num_alloc_blocks_per_batch = math.ceil(seqlen_k / + block_size) + block_table = torch.zeros( + (batch_size, max_blocks_per_batch), + device=q.device, + dtype=torch.int32, + ) + block_ids = torch.arange( + 0, + batch_size * num_alloc_blocks_per_batch, + device=q.device, + dtype=torch.int32, + ) + if randomize_blocks: + # Randomize the block ids. + block_ids = block_ids[torch.randperm( + block_ids.numel())] + block_table[:, : + num_alloc_blocks_per_batch] = block_ids.view( + -1, num_alloc_blocks_per_batch) + + # Setup the slot mapping for the input KVs. + tree_positions = sequence_position + torch.arange( + 0, + tree_size_q, + device=q.device, + dtype=torch.int64, + ).repeat(batch_size, 1) + tree_slot_mapping = _gen_slot_mapping( + tree_positions, block_table, block_size) + + # Compute attention for the tree. + tree_attn_output = forward_attention( + q=q, + k=k, + v=v, + kv_cache=kv_cache, + block_table=block_table, + slot_mapping=tree_slot_mapping, + seqlen_k=seqlen_k, + backend=_Backend.TREE_ATTN, + spec_token_tree=spec_token_tree, + num_spec_tokens=tree_size_q - 1, + ).view(batch_size, -1, num_heads, dim_per_head) + + # Verify that the chain attention output for each + # branch of the tree (computed using FA3) matches + # the tree attention output. + for q_index in range(tree_size_q): + # Get the q, k, and v for the branch. + branch_mask = tree_attn_mask[q_index, :] + branch_indices = torch.nonzero(branch_mask, + as_tuple=True)[0] + q_len = branch_indices.shape[0] + q_branch = q[:, branch_indices] + k_branch = k[:, branch_indices] + v_branch = v[:, branch_indices] + + # Setup slot mapping for the branch. + branch_positions = sequence_position + torch.arange( + 0, + q_len, + device=q.device, + dtype=torch.int64, + ).repeat(batch_size, 1) + branch_slot_mapping = _gen_slot_mapping( + branch_positions, block_table, block_size) + + # Compute flash attention for the branch. 
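+                        # This chain attention over the branch is the
+                        # reference; the tree attention rows selected by
+                        # branch_indices must match it.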
+ flash_attn_output = forward_attention( + q=q_branch, + k=k_branch, + v=v_branch, + kv_cache=kv_cache, + block_table=block_table, + slot_mapping=branch_slot_mapping, + seqlen_k=sequence_position + q_len, + backend=_Backend.FLASH_ATTN_VLLM_V1, + ).view(batch_size, -1, num_heads, dim_per_head) + + # Compare the outputs. + assert torch.allclose( + tree_attn_output[:, branch_indices], + flash_attn_output, + atol=7.81e-3, + ), (f"outputs are not close for " + f"batch_size: {batch_size}, " + f"num_heads: {num_heads}, " + f"sequence_position: {sequence_position}, " + f"tree_attn_mask: {tree_attn_mask}, " + f"q_index: {q_index}.") + + +def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor, + block_size: int): + block_indices = positions // block_size + blocks = block_table.gather(dim=1, index=block_indices) + return (blocks * block_size + positions % block_size).view(-1) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index eb9c4f1c1..0fdba569f 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -55,6 +55,7 @@ def kernel_unified_attention_2d( block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -66,10 +67,12 @@ def kernel_unified_attention_2d( query_stride_1: tl.int64, # int, should be equal to head_size output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int stride_k_cache_0: tl.int64, # int @@ -144,6 +147,11 @@ def kernel_unified_attention_2d( mask=query_mask_1, other=0.0) + # query-query attention bias + if USE_QQ_BIAS: + qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( @@ -223,6 +231,18 @@ def kernel_unified_attention_2d( if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = tl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + S += qq_bias + # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) @@ -275,6 +295,7 @@ def kernel_unified_attention_3d( block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -284,10 +305,12 @@ def kernel_unified_attention_3d( block_table_stride: tl.int64, # int query_stride_0: tl.int64, # int query_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # 
int HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int stride_k_cache_0: tl.int64, # int @@ -373,6 +396,11 @@ def kernel_unified_attention_3d( mask=query_mask_1, other=0.0) + # query-query attention bias + if USE_QQ_BIAS: + qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # iterate through tiles within current segment @@ -442,6 +470,18 @@ def kernel_unified_attention_3d( if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = tl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + S += qq_bias + # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) @@ -586,6 +626,7 @@ def unified_attention( k_descale, v_descale, alibi_slopes=None, + qq_bias=None, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -595,6 +636,7 @@ def unified_attention( "Block size must be at least 32 for fp8" use_alibi_slopes = alibi_slopes is not None + use_qq_bias = qq_bias is not None block_size = v.shape[1] num_seqs = len(seqused_k) @@ -630,6 +672,7 @@ def unified_attention( block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, @@ -641,10 +684,12 @@ def unified_attention( query_stride_1=q.stride(1), output_stride_0=out.stride(0), output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, USE_SOFTCAP=(softcap > 0), SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), @@ -699,6 +744,7 @@ def unified_attention( block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, @@ -708,10 +754,12 @@ def unified_attention( block_table_stride=block_table.stride(0), query_stride_0=q.stride(0), query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, USE_SOFTCAP=(softcap > 0), SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), diff --git a/vllm/config.py b/vllm/config.py index ee8f3dd98..871df455e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3049,6 +3049,19 @@ class SpeculativeConfig: f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([ + (i + 1) * (0, ) + for i in range(self.num_speculative_tokens) + ]) + else: + # Sort the token tree breadth-first. 
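+                # Sorting by (depth, node) keeps nodes of the same depth
+                # contiguous, e.g. [(0, 0), (0,), (1,)] becomes
+                # [(0,), (1,), (0, 0)].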
+ tree_choices = ast.literal_eval( + self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t))) + self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c94e440e5..5eb9660cd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1454,7 +1454,6 @@ class EngineArgs: "Please consider using other speculative decoding methods " "such as ngram, medusa, eagle, or deepseek_mtp.") - # No XFormers so far. V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", @@ -1469,6 +1468,7 @@ class EngineArgs: "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", + "TREE_ATTN", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a90910639..b61b39a92 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -270,6 +270,7 @@ class CudaPlatformBase(Platform): FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") @@ -287,6 +288,9 @@ class CudaPlatformBase(Platform): elif selected_backend == _Backend.FLASH_ATTN: logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 + elif selected_backend == _Backend.TREE_ATTN: + logger.info_once("Using Tree Attention backend on V1 engine.") + return TREE_ATTN_V1 from vllm.attention.selector import is_attn_backend_supported diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 997aee706..61ce868c1 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -62,6 +62,7 @@ class _Backend(enum.Enum): DIFFERENTIAL_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() + TREE_ATTN = enum.auto() class PlatformEnum(enum.Enum): diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py new file mode 100644 index 000000000..4fb748328 --- /dev/null +++ b/vllm/v1/attention/backends/tree_attn.py @@ -0,0 +1,452 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with TreeAttention.""" + +import ast +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + + +class TreeAttentionBackend(AttentionBackend): + + 
accept_output_buffer: bool = True + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + + @staticmethod + def get_name() -> str: + return "TREE_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["TreeAttentionImpl"]: + return TreeAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return TreeAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]: + return TreeAttentionMetadataBuilder + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + +@dataclass +class TreeAttentionMetadata: + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + num_prefills: int = 0 + num_decodes: int = 0 + + tree_attn_bias: Optional[torch.Tensor] = None + + # Cached Prefill/decode metadata. 
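+    # Populated lazily by the prefill_metadata/decode_metadata properties so
+    # the prefill/decode split is computed at most once.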
+ _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None + _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + q_start_loc = self.query_start_loc[self.num_decodes:] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[self.num_decodes:] + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = TreeAttentionMetadata( + num_actual_tokens=self.num_prefill_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc - q_start_loc[0], + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[self.num_decodes:], + slot_mapping=self.slot_mapping[self.num_decode_tokens:], + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + + q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[:self.num_decodes] + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = TreeAttentionMetadata( + num_actual_tokens=self.num_decode_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc, + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[:self.num_decodes], + slot_mapping=self.slot_mapping[:self.num_decode_tokens], + tree_attn_bias=self.tree_attn_bias, + ) + return self._cached_decode_metadata + + +class TreeAttentionMetadataBuilder( + AttentionMetadataBuilder[TreeAttentionMetadata]): + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + self.kv_cache_spec = kv_cache_spec + self.block_size = kv_cache_spec.block_size + + spec_config = vllm_config.speculative_config + spec_token_tree = (spec := spec_config) and spec.speculative_token_tree + tree_choices: list[tuple[int, + ...]] = (ast.literal_eval(spec_token_tree) + if spec_token_tree is not None else + [(0, )]) + # Construct the tree attention bias. 
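+        # The bias is an additive [tree_len, tree_len] mask over the tree of
+        # query tokens: 0 where a node attends to itself, the root, or one
+        # of its ancestors, and -inf everywhere else.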
+ depth_counts = _get_depth_counts(tree_choices) + self.tree_attn_bias = _prepare_tree_attn_bias( + tree_choices, + depth_counts, + dtype=torch.float32, + device=device, + ) + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.tree_attn_bias.shape[0]) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TreeAttentionMetadata: + decode_threshold = self.tree_attn_bias.shape[0] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=decode_threshold)) + + num_actual_tokens = common_attn_metadata.num_actual_tokens + q_start_loc = common_attn_metadata.query_start_loc + max_query_len = common_attn_metadata.max_query_len + kv_seqlens = common_attn_metadata.seq_lens + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + return TreeAttentionMetadata( + num_actual_tokens=num_actual_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_decodes=num_decodes, + max_query_len=max_query_len, + query_start_loc=q_start_loc, + max_seq_len=max_seq_len, + seq_lens=kv_seqlens, + block_table=block_table, + slot_mapping=slot_mapping, + tree_attn_bias=self.tree_attn_bias, + ) + + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + draft_index: int, + ) -> TreeAttentionMetadata: + # Cache the original tree attention bias. + orig_tree_attn_bias = self.tree_attn_bias + + if draft_index == 0: + # Use prefill for drafting at the root level. + self.tree_attn_bias = torch.empty(0) + else: + # Slice the tree attention bias for drafting. + query_len = common_attn_metadata.max_query_len + start, end = draft_index, draft_index + query_len + self.tree_attn_bias = self.tree_attn_bias[start:end, + start:end].contiguous() + + # Build attention bias. + attn_metadata = self.build(0, common_attn_metadata, fast_build=True) + + # Reset the tree attention bias to the original value. + self.tree_attn_bias = orig_tree_attn_bias + return attn_metadata + + +def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: + # Count the number of choices at each depth of the tree. + depth_counts = [] + prev_depth = 0 + for path in sorted_tree_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + return depth_counts + + +def _prepare_tree_attn_bias( + sorted_tree_choices: list[tuple[int, ...]], + depth_counts: list[int], + dtype: Optional[torch.dtype], + device: Optional[torch.device], +) -> torch.Tensor: + # +1 comes from the additional root node. + tree_len = len(sorted_tree_choices) + 1 + tree_attn_mask = torch.full((tree_len, tree_len), + -torch.inf, + device=device, + dtype=dtype) + + # Set diagonal to all zeros. Each token should + # attend to itself. + mask_val = 0 + for i in range(tree_len): + tree_attn_mask[i, i] = mask_val + + # Set root to all zeros. All tokens attend to it. + tree_attn_mask[:, 0] = mask_val + + # Set all ancestors to zeros. 
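+    # Example: the chain [(0,), (0, 0)] yields a 3 x 3 lower-triangular bias
+    # (0 on and below the diagonal, -inf above), while the two-child tree
+    # [(0,), (1,)] keeps -inf between the sibling nodes.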
+ start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + # Retrieve ancestor position. + if len(cur_tree_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_tree_choice) - 1): + ancestor_idx.append( + sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + tree_attn_mask[j + start + 1, ancestor_idx] = mask_val + start += depth_counts[i] + return tree_attn_mask + + +class TreeAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "TreeAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if logits_soft_cap is None: + # Setting logits_soft_cap to 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + + TreeAttentionBackend.validate_head_size(head_size) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TreeAttentionImpl.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TreeAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with TreeAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TreeAttentionImpl") + + if attn_metadata is None: + # Profiling run. + return output + + # Cache the input KVs. + key_cache, value_cache = kv_cache.unbind(0) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. 
+ ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_actual_tokens = attn_metadata.num_actual_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, + key.shape[1]) + if prefill_meta := attn_metadata.prefill_metadata: + unified_attention( + q=query[num_decode_tokens:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[num_decode_tokens:num_actual_tokens], + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=prefill_meta.block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + if decode_meta := attn_metadata.decode_metadata: + unified_attention( + q=query[:num_decode_tokens], + k=key_cache, + v=value_cache, + out=output[:num_decode_tokens], + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_query_len, + seqused_k=decode_meta.seq_lens, + max_seqlen_k=decode_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + qq_bias=decode_meta.tree_attn_bias, + window_size=self.sliding_window, + block_table=decode_meta.block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 48bd63222..7aeea40b2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): return self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + draft_index: int, + ) -> M: + """ + Build attention metadata for draft model. Uses build by default. + + Args: + common_attn_metadata: The common attention metadata. + draft_index: The index of the current draft operation. + When speculating a chain of tokens, this index refers to the + draft attempt for the i-th token. + For tree-based attention, this index instead refers to the + draft attempt for the i-th level in the tree of tokens. 
+ """ + return self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True) + def use_cascade_attention( self, common_prefix_len: int, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 302126dbe..b2380bb3d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +from dataclasses import replace from typing import Optional import numpy as np @@ -17,6 +19,8 @@ from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, + TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata @@ -74,18 +78,52 @@ class EagleProposer: (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, - device=device, - dtype=torch.int32) + + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange( + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + max_batch_size + 1, + device=device, + dtype=torch.int32, + ) self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) + # Parse the speculative token tree. + spec_token_tree = self.speculative_config.speculative_token_tree + self.tree_choices: list[tuple[int, + ...]] = ast.literal_eval(spec_token_tree) + tree_depth = len(self.tree_choices[-1]) + # Precompute per-level properties of the tree. + num_drafts_per_level = [0] * tree_depth + for node in self.tree_choices: + num_drafts_per_level[len(node) - 1] += 1 + self.cu_drafts_per_level = [num_drafts_per_level[0]] + self.child_drafts_per_level = [num_drafts_per_level[0]] + for level in range(1, tree_depth): + self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + + num_drafts_per_level[level]) + self.child_drafts_per_level.append(num_drafts_per_level[level] // + num_drafts_per_level[level - 1]) + # Find the first level where the tree branches off into one or more + # children. + self.first_branching_level = None + for level in range(tree_depth): + if self.cu_drafts_per_level[level] > level + 1: + self.first_branching_level = level + break + # Precompute draft position offsets in flattened tree. 
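+        # Offsets 1..num_tree_nodes assign every flattened tree node a
+        # distinct KV-cache slot after each request's current position.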
+ self.tree_draft_pos_offsets = torch.arange( + 1, + len(self.tree_choices) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) + def propose( self, # [num_tokens] @@ -120,11 +158,9 @@ class EagleProposer: assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[0].build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - fast_build=True, - ) + attn_metadata = self.runner.attn_metadata_builders[ + 0].build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=0) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -167,6 +203,22 @@ class EagleProposer: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + positions = target_positions[last_token_indices] + hidden_states = hidden_states[last_token_indices] + if self.first_branching_level == 0: + # Branching has occurred at the root level. Draft using tree + # attention. + draft_token_ids_list = self.propose_tree( + tree_root_level=0, + batch_size=batch_size, + logits=logits, + positions=positions, + hidden_states=hidden_states, + common_attn_metadata=common_attn_metadata, + ) + # [batch_size, num_tree_tokens] + return torch.cat(draft_token_ids_list, dim=1) + draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. @@ -178,16 +230,15 @@ class EagleProposer: # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - # Currently FlashAttention is the only backend that supports - # multi-token eagle spec decode. This is because the code below + # Currently, only FlashAttention and TreeAttention support multi-token + # eagle spec decode. This is because the code below # makes assumptions about attn_metadata attributes available. - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert isinstance(attn_metadata, + (FlashAttentionMetadata, TreeAttentionMetadata)) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) @@ -196,7 +247,7 @@ class EagleProposer: attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -265,7 +316,20 @@ class EagleProposer: logits = self.model.compute_logits(last_hidden_states[:batch_size], None) - # TODO(wenlong): get more than one token for tree attention + if self.first_branching_level == token_index + 1: + # Branching has occurred. The remaining tokens are drafted + # using tree attention. 
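+                # propose_tree returns one tensor of draft token ids per
+                # remaining tree level.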
+ draft_token_ids_list += self.propose_tree( + tree_root_level=token_index + 1, + batch_size=batch_size, + logits=logits, + positions=positions, + hidden_states=hidden_states, + common_attn_metadata=common_attn_metadata, + ) + # [batch_size, num_tree_tokens] + return torch.cat(draft_token_ids_list, dim=1) + draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -273,6 +337,175 @@ class EagleProposer: draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def propose_tree( + self, + tree_root_level: int, + batch_size: int, + # [num_tokens, vocab_size] + logits: torch.Tensor, + # [num_tokens] + positions: torch.Tensor, + # [num_tokens, hidden_size] + hidden_states: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + ) -> list[torch.Tensor]: + tree_attn_metadata_builder = self.runner.attn_metadata_builders[0] + assert isinstance(tree_attn_metadata_builder, + TreeAttentionMetadataBuilder) + + total_num_drafts = self.cu_drafts_per_level[tree_root_level] + level_num_drafts = total_num_drafts + # Sample a draft token for each child at the tree root level. + num_children = self.child_drafts_per_level[tree_root_level] + if num_children == 1: + draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) + else: + draft_token_ids = torch.topk(logits, num_children, + dim=-1).indices.view(batch_size, -1) + draft_token_ids_list = [draft_token_ids] + draft_hidden_states = hidden_states.view(batch_size, 1, -1) + + # Initialize empty tensors for concatenation with the level outputs. + tree_input_ids = torch.empty(0, + device=self.input_ids.device, + dtype=self.input_ids.dtype) + tree_positions = torch.empty(0, + device=self.positions.device, + dtype=self.positions.dtype) + tree_hidden_states = torch.empty(0, + device=self.hidden_states.device, + dtype=self.hidden_states.dtype) + # Precompute the draft token positions. + flattened_draft_positions = ( + positions.view(batch_size, -1) + + self.tree_draft_pos_offsets[:batch_size, :]) + tree_depth = len(self.cu_drafts_per_level) + for level in range(tree_root_level, tree_depth - 1): + # Get draft positions for RoPE. + draft_positions = positions + (level + 1) + exceeds_max_model_len = (positions + + total_num_drafts) >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_draft_positions = torch.where( + exceeds_max_model_len, + 0, + draft_positions, + ) + if level_num_drafts > 1: + # Repeat the positions for each draft at this level. + draft_positions = clamped_draft_positions.repeat_interleave( + level_num_drafts).reshape(batch_size, -1) + + if num_children > 1: + # Repeat draft hidden states for each child. + draft_hidden_states = draft_hidden_states.repeat_interleave( + num_children, dim=1) + + # Concatenate the draft tokens, positions, and hidden states. + tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], + dim=1) + tree_positions = torch.cat([tree_positions, draft_positions], + dim=1) + tree_hidden_states = torch.cat( + [tree_hidden_states, draft_hidden_states], dim=1) + + # Build new attention metadata for the next level of drafts. + # This is necessary to support tree attention. 
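+            # Each level runs the draft model over every token drafted since
+            # the tree root, so the query length grows level by level.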
+ query_len = total_num_drafts - tree_root_level + common_attn_metadata = replace( + common_attn_metadata, + query_start_loc=query_len * self.arange[:batch_size + 1], + seq_lens=common_attn_metadata.seq_lens + level_num_drafts, + num_actual_tokens=batch_size * query_len, + max_query_len=query_len, + ) + attn_metadata = tree_attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=tree_root_level + 1, + ) + + # Apply new attention metadata to all layers. + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + + # Consider max model length. + attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + self.max_model_len) + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + + # Compute the slot mapping. + query_positions = flattened_draft_positions[:, level:level + + query_len] + block_numbers = query_positions // self.block_size + block_ids = attn_metadata.block_table.gather(dim=1, + index=block_numbers) + slot_mapping = (block_ids * self.block_size + + query_positions % self.block_size) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID + attn_metadata.slot_mapping = slot_mapping.view(-1) + + # Copy inputs to buffer for cudagraph. + num_tokens = attn_metadata.num_actual_tokens + input_ids = tree_input_ids.view(-1) + self.input_ids[:num_tokens] = input_ids + self.positions[:num_tokens] = tree_positions.view(-1) + self.hidden_states[:num_tokens] = tree_hidden_states.view( + num_tokens, -1) + + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_tokens) + else: + num_input_tokens = num_tokens + # Run the model. + with set_forward_context(per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + inputs_embeds=None, + ) + + # Get the output hidden states for the draft tokens. + draft_hidden_states = hidden_states[:num_tokens].view( + batch_size, query_len, -1)[:, -level_num_drafts:] + draft_last_hidden_states = last_hidden_states[:num_tokens].view( + batch_size, query_len, -1)[:, -level_num_drafts:] + + # Get the output logits for the draft tokens. + logits = self.model.compute_logits( + draft_last_hidden_states.reshape(batch_size * level_num_drafts, + -1), + None, + ) + + # Sample a draft token for each child at the next tree level. + num_children = self.child_drafts_per_level[level + 1] + if num_children == 1: + draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) + else: + draft_token_ids = torch.topk(logits, num_children, + dim=-1).indices.view( + batch_size, -1) + draft_token_ids_list.append(draft_token_ids) + + # Update the # drafts counters for the next tree level. + level_num_drafts = self.cu_drafts_per_level[level + + 1] - total_num_drafts + total_num_drafts = self.cu_drafts_per_level[level + 1] + + return draft_token_ids_list + def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata,