[v1] AttentionMetadata for each layer (#17394)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
                          get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -394,16 +395,15 @@ class FlashInferMetadataBuilder:
         )
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
         page_size = self.runner.block_size
         device = self.runner.device
-        qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            self.runner.device, non_blocking=True)
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
-                                                          non_blocking=True)
+        qo_indptr = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
Reference in New Issue
Block a user