diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index b481a1f16..3c5b99904 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -750,11 +750,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat self.max_model_len = self.model_config.max_model_len max_num_seqs = vllm_config.scheduler_config.max_num_seqs max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self.max_num_q_block = ( - self.max_model_len + self.q_block_size - 1 - ) // self.q_block_size + self.max_num_query_groups = cdiv(max_num_batched_tokens, self.q_block_size) + max_num_pages_per_seq = cdiv(self.max_model_len, self.block_size) + self.max_num_kv_indices = self.q_block_size * max_num_pages_per_seq self.persistent_kv_num_blocks = torch.empty( - self.max_num_q_block, dtype=torch.int32, device=device + self.max_num_query_groups, dtype=torch.int32, device=device ) self.persistent_offset_tensor = torch.empty( max_num_seqs, dtype=torch.int32, device=device @@ -828,12 +828,9 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ) if self.persistent_kv_indices is None: - max_num_kv_block = ( - self.max_model_len + self.kv_block_size - 1 - ) // self.kv_block_size self.persistent_kv_indices = torch.empty( - self.max_model_len, - max_num_kv_block, + self.max_num_query_groups, + self.max_num_kv_indices, dtype=torch.int32, device=self.device, )