[V1] Implement Cascade Attention (#11635)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -2,10 +2,14 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import numpy as np
import torch
import triton
import triton.language as tl

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.utils import cdiv
from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -38,6 +42,10 @@ class FlashAttentionBackend(AttentionBackend):
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return use_cascade_attention(*args, **kwargs)


@dataclass
class FlashAttentionMetadata:
@@ -56,6 +64,15 @@ class FlashAttentionMetadata:
    seq_start_loc: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    cu_prefix_kv_lens: Optional[torch.Tensor]
    cu_suffix_kv_lens: Optional[torch.Tensor]

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

@@ -169,21 +186,245 @@ class FlashAttentionImpl(AttentionImpl):
        )

        # Compute attention and update output up to `num_actual_tokens`.
        if not attn_metadata.use_cascade:
            # Regular attention (common case).
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=attn_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                cu_seqlens_k=attn_metadata.seq_start_loc,
                max_seqlen_k=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
            )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
            cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
        )
        return output

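For intuition, here is how the cascade metadata could look for a hypothetical two-request batch (the numbers are illustrative, not taken from this commit): request 0 has 3 query tokens and 300 KV tokens, request 1 has 2 query tokens and 280 KV tokens, and both share a 256-token prefix.

import torch

# Per-request query boundaries: 3 + 2 = 5 query tokens in total.
cu_query_lens = torch.tensor([0, 3, 5], dtype=torch.int32)

# The prefix kernel treats all 5 query tokens as ONE sequence that
# attends, without masking, to the single shared 256-token prefix.
cu_prefix_query_lens = torch.tensor([0, 5], dtype=torch.int32)
cu_prefix_kv_lens = torch.tensor([0, 256], dtype=torch.int32)

# The suffix kernel sees only the per-request KV beyond the prefix:
# 300 - 256 = 44 and 280 - 256 = 24 tokens, accumulated as offsets.
cu_suffix_kv_lens = torch.tensor([0, 44, 68], dtype=torch.int32)
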
def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade
    #    attention is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time

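A hedged usage sketch of the heuristic above, with assumed shapes chosen to exercise the FlashDecoding branch (a pure-decode batch under GQA); all numbers are illustrative:

import numpy as np

# 32 decode requests (one query token each), 32 query heads over 8 KV
# heads, a 1024-token shared prefix, and 108 SMs (e.g., an A100).
# Performance model: num_prefix_tiles = cdiv(1024, 128) = 8
#                    cascade_time = cdiv(32 * 1, 108) * 8 = 8
#                    flash_decoding_time = cdiv(32 * 8 * 1 * 8, 108) = 19
# Since 8 < 19, the heuristic picks cascade attention for this shape.
assert use_cascade_attention(
    common_prefix_len=1024,
    query_lens=np.ones(32, dtype=np.int32),
    num_query_heads=32,
    num_kv_heads=8,
    use_alibi=False,
    use_sliding_window=False,
    num_sms=108,
)
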
def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    cu_prefix_kv_lens: torch.Tensor,
    cu_suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: Tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
) -> torch.Tensor:
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        cu_seqlens_k=cu_prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
    )

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)

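The correctness of the two-kernel split rests on the standard log-sum-exp decomposition of softmax attention, sketched here (this derivation is an editorial note, not part of the original comments):

# For a query q with prefix keys P and suffix keys S, attention over the
# concatenation P + S decomposes into the two partial results:
#
#   out = (exp(lse_P) * out_P + exp(lse_S) * out_S)
#         / (exp(lse_P) + exp(lse_S))
#
# where lse_P = log(sum over k in P of exp(q.k * scale)) and out_P is the
# softmax-weighted value over P alone (likewise for S). The two
# flash_attn_varlen_func calls above return exactly these per-head LSEs
# via return_softmax_lse=True, so merge_attn_states can combine the
# partial outputs without rerunning attention.
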
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
) -> None:
    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states_kernel[(num_tokens, num_query_heads)](
        output,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
    )

@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)
    s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
    s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + token_idx * num_heads * HEAD_SIZE +
             head_idx * HEAD_SIZE + head_arange,
             out,
             mask=head_mask)

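As a sanity check, a plain PyTorch equivalent of the kernel's merge (a hypothetical test helper, assuming the tensor layouts documented in the kernel signature above):

import torch

def merge_attn_states_ref(prefix_output: torch.Tensor,
                          prefix_lse: torch.Tensor,
                          suffix_output: torch.Tensor,
                          suffix_lse: torch.Tensor) -> torch.Tensor:
    # LSEs come in as [NUM_HEADS, NUM_TOKENS]; reshape them to broadcast
    # against the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] outputs.
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1)
    # Subtract the running max before exponentiating, mirroring the
    # kernel's trick for numerical stability.
    max_lse = torch.maximum(p_lse, s_lse)
    p_weight = torch.exp(p_lse - max_lse)
    s_weight = torch.exp(s_lse - max_lse)
    return (prefix_output * p_weight + suffix_output * s_weight) / (
        p_weight + s_weight)
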
@@ -8,7 +8,7 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                         generate_block_hash_extra_keys,
                                         hash_block_tokens,
                                         hash_request_tokens)
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)

@@ -278,6 +278,56 @@ class KVCacheManager:
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def get_num_common_prefix_blocks(
        self,
        request: Request,
        num_running_requests: int,
    ) -> int:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state.

        The function determines this by selecting any request and iterating
        through its blocks. A block is considered a common prefix block if its
        `ref_cnt` equals the total number of requests in the RUNNING state.

        NOTE(woosuk): The number of requests in the RUNNING state is **greater
        than or equal to** the number of requests scheduled in the current
        step. This is because the RUNNING state only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must be in the RUNNING state, the inverse
        is not necessarily true. There may be RUNNING requests that are not
        scheduled in the current step. As of 1/1/2025, the scheduler does not
        allow this case, but it is possible in the future, as we allow more
        flexible scheduling.

        This can result in an edge case where the number of common prefix
        blocks is 0, even though all scheduled requests share a common prefix.
        This occurs because there may be unscheduled RUNNING requests that do
        not share the common prefix. Currently, this case cannot be easily
        detected, so the function returns 0 in such cases.

        Args:
            request: Any request in the RUNNING state, used to identify the
                common prefix blocks.
            num_running_requests: The total number of requests in the RUNNING
                state. This can be different from the number of scheduled
                requests in the current step.

        Returns:
            int: The number of common prefix blocks.
        """
        assert request.status == RequestStatus.RUNNING
        blocks = self.req_to_blocks[request.request_id]
        num_common_blocks = 0
        for block in blocks:
            if block.ref_cnt == num_running_requests:
                num_common_blocks += 1
            else:
                break
        return num_common_blocks

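A toy illustration of the ref_cnt walk above (hypothetical counts, not from this commit): suppose 3 RUNNING requests whose block tables all begin with the same two blocks.

# blocks[0].ref_cnt == 3  -> shared by all 3 RUNNING requests: common
# blocks[1].ref_cnt == 3  -> shared by all 3 RUNNING requests: common
# blocks[2].ref_cnt == 1  -> unique to this request: stop the walk
#
# get_num_common_prefix_blocks(request, 3) returns 2, which the model
# runner later converts to common_prefix_len = 2 * block_size tokens.
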
    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool.

@@ -262,6 +262,14 @@ class Scheduler:
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) == len(self.running))

        # Get the longest common prefix among all requests in the running
        # queue. This can potentially be used for cascade attention.
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
@@ -287,6 +295,7 @@ class Scheduler:
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            preempted_req_ids=preempted_req_ids,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
@@ -594,6 +603,7 @@ class SchedulerOutput:
    num_scheduled_tokens: Dict[str, int]
    total_num_scheduled_tokens: int
    scheduled_encoder_inputs: Dict[str, List[int]]
    num_common_prefix_blocks: int

    preempted_req_ids: Set[str]
    finished_req_ids: Set[str]

@@ -72,6 +72,8 @@ class GPUModelRunner:
        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
@@ -118,6 +120,10 @@ class GPUModelRunner:
        self.cudagraph_batch_sizes = list(
            reversed(self.vllm_config.compilation_config.capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
@@ -131,7 +137,8 @@ class GPUModelRunner:
                                     device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
@@ -355,6 +362,88 @@ class GPUModelRunner:
            self.device, non_blocking=True)
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True).long()

        # Prepare for cascade attention if needed.
        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
                             self.block_size)
        if common_prefix_len == 0:
            # Common case.
            use_cascade = False
        else:
            # NOTE(woosuk): Cascade attention uses two attention kernels: one
            # for the common prefix and the other for the rest. For the first
            # kernel, we concatenate all the query tokens (possibly from
            # different requests) and treat them as if they are from the same
            # request. Then, we use bi-directional attention to process the
            # common prefix in the KV cache. Importantly, this means that the
            # first kernel does not do any masking.

            # Consider the following example:
            # Request 1's input query: [D, E, X]
            # Request 1's kv cache: [A, B, C, D, E, X]
            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
            # Request 2's input query: [E, Y]
            # Request 2's kv cache: [A, B, C, D, E, Y]
            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

            # If we use [A, B, C, D, E] as the common prefix, then the
            # first kernel will compute the bi-directional attention between
            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
            # However, this is wrong because D in Request 1 should not attend
            # to E in the common prefix (i.e., we need masking).
            # To avoid this, [A, B, C, D] should be the common prefix.
            # That is, the common prefix should be capped by the minimum
            # num_computed_tokens among the requests, plus one to include
            # the first token of the query.

            # In practice, we use [A, B, C] as the common prefix, instead of
            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
            # num_computed_tokens, without the plus one).
            # This is because of an implementation detail: We want to always
            # use two kernels for cascade attention. Let's imagine:
            # Request 3's input query: [D]
            # Request 3's kv cache: [A, B, C, D]
            # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
            # If we use [A, B, C, D] as the common prefix for Request 1-3,
            # then Request 3 will be processed only by the first kernel,
            # and the second kernel will get an empty input. While this is not
            # a fundamental problem, our current implementation does not
            # support this case.
            common_prefix_len = min(
                common_prefix_len,
                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
            # common_prefix_len should be a multiple of the block size.
            common_prefix_len = (common_prefix_len // self.block_size *
                                 self.block_size)
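            # Illustrative arithmetic (hypothetical numbers, added for
            # clarity): with block_size=16 and 20 shared blocks (320 tokens),
            # if the minimum num_computed_tokens across requests is 300, then
            #   common_prefix_len = min(320, 300) = 300
            #   common_prefix_len = 300 // 16 * 16 = 288  # 18 full blocks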
            use_cascade = FlashAttentionBackend.use_cascade_attention(
                common_prefix_len=common_prefix_len,
                query_lens=num_scheduled_tokens,
                num_query_heads=self.num_query_heads,
                num_kv_heads=self.num_kv_heads,
                use_alibi=False,  # FIXME
                use_sliding_window=self.sliding_window is not None,
                num_sms=self.num_sms,
            )

        if use_cascade:
            # TODO: Optimize.
            cu_prefix_query_lens = torch.tensor(
                [0, total_num_scheduled_tokens],
                dtype=torch.int32,
                device=self.device)
            cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                             dtype=torch.int32,
                                             device=self.device)
            cu_suffix_kv_lens = (
                self.seq_start_loc_np[:num_reqs + 1] -
                self.arange_np[:num_reqs + 1] * common_prefix_len)
            cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
                self.device)
        else:
            cu_prefix_query_lens = None
            cu_prefix_kv_lens = None
            cu_suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
@@ -363,6 +452,11 @@ class GPUModelRunner:
            seq_start_loc=seq_start_loc,
            block_table=self.input_batch.block_table[:num_reqs],
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            cu_prefix_kv_lens=cu_prefix_kv_lens,
            cu_suffix_kv_lens=cu_suffix_kv_lens,
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this