[v1] AttentionMetadata for each layer (#17394)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-05-06 22:58:37 +08:00
committed by GitHub
parent a6fed02068
commit cba31c47c4
9 changed files with 126 additions and 46 deletions

View File

@@ -18,6 +18,7 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
get_layers_from_vllm_config)
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -394,16 +395,15 @@ class FlashInferMetadataBuilder:
)
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.runner.block_size
device = self.runner.device
qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
non_blocking=True)
qo_indptr = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(