[Model Runner V2] Fix flex attention kv blocks calculation issue (#39353)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -750,11 +750,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_q_block = (
|
||||
self.max_model_len + self.q_block_size - 1
|
||||
) // self.q_block_size
|
||||
self.max_num_query_groups = cdiv(max_num_batched_tokens, self.q_block_size)
|
||||
max_num_pages_per_seq = cdiv(self.max_model_len, self.block_size)
|
||||
self.max_num_kv_indices = self.q_block_size * max_num_pages_per_seq
|
||||
self.persistent_kv_num_blocks = torch.empty(
|
||||
self.max_num_q_block, dtype=torch.int32, device=device
|
||||
self.max_num_query_groups, dtype=torch.int32, device=device
|
||||
)
|
||||
self.persistent_offset_tensor = torch.empty(
|
||||
max_num_seqs, dtype=torch.int32, device=device
|
||||
@@ -828,12 +828,9 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
|
||||
)
|
||||
|
||||
if self.persistent_kv_indices is None:
|
||||
max_num_kv_block = (
|
||||
self.max_model_len + self.kv_block_size - 1
|
||||
) // self.kv_block_size
|
||||
self.persistent_kv_indices = torch.empty(
|
||||
self.max_model_len,
|
||||
max_num_kv_block,
|
||||
self.max_num_query_groups,
|
||||
self.max_num_kv_indices,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user