[Model Runner V2] Fix flex attention kv blocks calculation issue (#39353)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-04-09 13:07:43 -04:00
committed by GitHub
parent 9036d4c464
commit 56e19d7ee2

View File

@@ -750,11 +750,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.max_model_len = self.model_config.max_model_len
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_q_block = (
self.max_model_len + self.q_block_size - 1
) // self.q_block_size
self.max_num_query_groups = cdiv(max_num_batched_tokens, self.q_block_size)
max_num_pages_per_seq = cdiv(self.max_model_len, self.block_size)
self.max_num_kv_indices = self.q_block_size * max_num_pages_per_seq
self.persistent_kv_num_blocks = torch.empty(
self.max_num_q_block, dtype=torch.int32, device=device
self.max_num_query_groups, dtype=torch.int32, device=device
)
self.persistent_offset_tensor = torch.empty(
max_num_seqs, dtype=torch.int32, device=device
@@ -828,12 +828,9 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
)
if self.persistent_kv_indices is None:
max_num_kv_block = (
self.max_model_len + self.kv_block_size - 1
) // self.kv_block_size
self.persistent_kv_indices = torch.empty(
self.max_model_len,
max_num_kv_block,
self.max_num_query_groups,
self.max_num_kv_indices,
dtype=torch.int32,
device=self.device,
)