[Bugfix] Fix block size used in EAGLE slot mapping (#31540)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-01-01 22:32:30 -05:00
committed by GitHub
parent 27864a851c
commit ea53ca5e85

View File

@@ -71,7 +71,6 @@ class EagleProposer:
self.device = device
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
@@ -470,22 +469,23 @@ class EagleProposer:
common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping.
block_size = attn_metadata_builder.kv_cache_spec.block_size
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // self.block_size
block_numbers = clamped_positions[0] // block_size
else:
block_numbers = clamped_positions // self.block_size
block_numbers = clamped_positions // block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1)
)
block_ids = block_ids.view(-1)
if self.uses_mrope:
common_attn_metadata.slot_mapping = (
block_ids * self.block_size + clamped_positions[0] % self.block_size
block_ids * block_size + clamped_positions[0] % block_size
)
else:
common_attn_metadata.slot_mapping = (
block_ids * self.block_size + clamped_positions % self.block_size
block_ids * block_size + clamped_positions % block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
@@ -800,12 +800,11 @@ class EagleProposer:
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
query_positions = flattened_draft_positions[:, level : level + query_len]
block_numbers = query_positions // self.block_size
block_numbers = query_positions // block_size
block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
slot_mapping = (
block_ids * self.block_size + query_positions % self.block_size
)
slot_mapping = block_ids * block_size + query_positions % block_size
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.