[Model Runner V2] Fix slot_mapping after #25954 (#33046)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Authored by Woosuk Kwon on 2026-01-25 18:29:49 -08:00, committed by GitHub
parent 22aeb43007
commit edf927bc9f
5 changed files with 40 additions and 23 deletions
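
Overview: this change makes the layer-keyed slot-mapping dict a first-class output of input preparation. build_slot_mappings_by_layer is now called once when inputs are built (and once when capture inputs are prepared), its result is cached on InputBatch.slot_mappings, and both the main forward pass and the Eagle speculator read the cached dict instead of rebuilding it at call time. A minimal sketch of what the helper plausibly does, inferred from the call sites in the hunks below; the real implementation lives in vllm.v1.worker.gpu.attn_utils and may differ:

# Hypothetical sketch, not the actual vLLM implementation. Assumes
# slot_mappings is indexed by KV cache group ([num_kv_cache_groups,
# num_tokens]) and that KVCacheConfig.kv_cache_groups exposes layer_names.
import torch
from vllm.v1.kv_cache_interface import KVCacheConfig

def build_slot_mappings_by_layer(
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
) -> dict[str, torch.Tensor]:
    # Fan the per-group slot mappings out into a layer_name -> slot_mapping
    # dict, matching the InputBatch.slot_mappings annotation in this diff.
    by_layer: dict[str, torch.Tensor] = {}
    for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
        for layer_name in group.layer_names:
            # All layers in a group share one slot mapping tensor.
            by_layer[layer_name] = slot_mappings[group_id]
    return by_layer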

View File

@@ -101,9 +101,6 @@ class CudaGraphManager:
             kv_cache_config,
         )
         num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
-        slot_mappings_by_layer = build_slot_mappings_by_layer(
-            slot_mappings, kv_cache_config
-        )
         # Warm up.
         with set_forward_context(
@@ -112,7 +109,7 @@ class CudaGraphManager:
             num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
-            slot_mapping=slot_mappings_by_layer,
+            slot_mapping=slot_mappings,
         ):
             hidden_states = model(
                 input_ids=input_ids,
@@ -132,7 +129,7 @@ class CudaGraphManager:
                 num_tokens=num_tokens,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
-                slot_mapping=slot_mappings_by_layer,
+                slot_mapping=slot_mappings,
             ),
             torch.cuda.graph(graph, self.pool),
         ):
@@ -252,7 +249,7 @@ def prepare_inputs_to_capture(
     attn_metadata_builders: list[AttentionMetadataBuilder],
     max_model_len: int,
     kv_cache_config: KVCacheConfig,
-) -> tuple[dict[str, Any], torch.Tensor]:
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
     num_tokens_per_req = num_tokens // num_reqs
     query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
@@ -269,6 +266,9 @@ def prepare_inputs_to_capture(
     input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
     slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+    slot_mappings_by_layer = build_slot_mappings_by_layer(
+        slot_mappings, kv_cache_config
+    )
     attn_metadata = build_attn_metadata(
         attn_metadata_builders=attn_metadata_builders,
@@ -282,4 +282,4 @@ def prepare_inputs_to_capture(
         slot_mappings=slot_mappings,
         kv_cache_config=kv_cache_config,
     )
-    return attn_metadata, slot_mappings
+    return attn_metadata, slot_mappings_by_layer
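
Both values returned by prepare_inputs_to_capture are now layer-keyed dicts that travel together into set_forward_context. A hedged sketch of the capture-time calling convention after this change, reusing names from the hunks above (argument order and surrounding plumbing are illustrative, not verbatim vLLM code):

# Illustrative fragment: names are taken from the hunks above.
attn_metadata, slot_mappings = prepare_inputs_to_capture(
    num_reqs,
    num_tokens,
    input_buffers,
    block_tables,
    attn_metadata_builders,
    max_model_len,
    kv_cache_config,
)
with set_forward_context(
    attn_metadata,
    vllm_config,
    num_tokens=num_tokens,
    cudagraph_runtime_mode=CUDAGraphMode.NONE,
    num_tokens_across_dp=num_tokens_across_dp,
    slot_mapping=slot_mappings,  # layer_name -> slot_mapping
):
    ...  # warm-up / capture body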

View File

@@ -66,6 +66,8 @@ class InputBatch:
     # layer_name -> Metadata
     attn_metadata: dict[str, Any]
+    # layer_name -> slot_mapping
+    slot_mappings: dict[str, torch.Tensor]
     # [total_num_logits]
     logits_indices: torch.Tensor
@@ -133,6 +135,7 @@ class InputBatch:
             mrope_positions=None,
             inputs_embeds=None,
             attn_metadata=None,  # type: ignore
+            slot_mappings=None,  # type: ignore
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,

View File

@@ -269,6 +269,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         slot_mappings = self.block_tables.get_dummy_slot_mappings(
             input_batch.num_tokens
         )
+        slot_mappings_by_layer = build_slot_mappings_by_layer(
+            slot_mappings, self.kv_cache_config
+        )
         attn_metadata = build_attn_metadata(
             attn_metadata_builders=self.attn_metadata_builders,
             num_reqs=input_batch.num_reqs,
@@ -282,6 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             kv_cache_config=self.kv_cache_config,
         )
         input_batch.attn_metadata = attn_metadata
+        input_batch.slot_mappings = slot_mappings_by_layer

     @torch.inference_mode()
     def _dummy_run(
@@ -345,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.speculator.run_model(
             self.max_num_tokens,
             attn_metadata=None,
+            slot_mappings=None,
             num_tokens_across_dp=num_tokens_across_dp,
         )
         torch.cuda.synchronize()
@@ -615,6 +620,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             query_start_loc,
             self.input_buffers.positions[:num_tokens],
         )
+        # Layer name -> slot mapping.
+        slot_mappings_by_layer = build_slot_mappings_by_layer(
+            slot_mappings, self.kv_cache_config
+        )

         # Layer name -> attention metadata.
         attn_metadata = build_attn_metadata(
@@ -655,6 +664,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             mrope_positions=mrope_positions,
             inputs_embeds=None,
             attn_metadata=attn_metadata,
+            slot_mappings=slot_mappings_by_layer,
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
@@ -882,14 +892,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.uses_mrope:
             assert input_batch.mrope_positions is not None
             positions = input_batch.mrope_positions
-        slot_mappings = self.block_tables.compute_slot_mappings(
-            input_batch.idx_mapping,
-            input_batch.query_start_loc,
-            input_batch.positions[: input_batch.num_tokens],
-        )
-        slot_mappings_by_layer = build_slot_mappings_by_layer(
-            slot_mappings, self.kv_cache_config
-        )
         with set_forward_context(
             input_batch.attn_metadata,
             self.vllm_config,
@@ -897,7 +899,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # TODO(woosuk): Support piecewise CUDA graph.
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
-            slot_mapping=slot_mappings_by_layer,
+            slot_mapping=input_batch.slot_mappings,
         ):
             self.kv_connector.pre_forward(scheduler_output)
             hidden_states = self.model(
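
Net effect in the runner: the slot mappings are computed and expanded per layer once in input preparation and cached on the InputBatch, so the execute path drops its duplicate compute_slot_mappings/build_slot_mappings_by_layer pair and reads input_batch.slot_mappings instead. This removes redundant per-step work and keeps the forward pass and the speculator working from the same dict.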

View File

@@ -13,7 +13,10 @@ from vllm.model_executor.model_loader import get_model
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
+from vllm.v1.worker.gpu.attn_utils import (
+    build_attn_metadata,
+    build_slot_mappings_by_layer,
+)
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -108,7 +111,8 @@ class EagleSpeculator:
     def run_model(
         self,
         num_tokens: int,
-        attn_metadata: dict[str, Any],
+        attn_metadata: dict[str, Any] | None,
+        slot_mappings: dict[str, torch.Tensor] | None,
         num_tokens_across_dp: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         with set_forward_context(
@@ -117,6 +121,7 @@ class EagleSpeculator:
             num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
+            slot_mapping=slot_mappings,
         ):
             ret_hidden_states = self.model(
                 input_ids=self.input_buffers.input_ids[:num_tokens],
@@ -134,6 +139,7 @@ class EagleSpeculator:
         self,
         num_reqs: int,
         attn_metadata: dict[str, Any],
+        slot_mappings: dict[str, torch.Tensor],
         num_tokens_across_dp: torch.Tensor | None,
     ) -> None:
         pos = self.input_buffers.positions[:num_reqs]
@@ -142,7 +148,7 @@ class EagleSpeculator:
         for step in range(1, self.num_speculative_steps):
             # Run the eagle model.
             last_hidden_states, hidden_states = self.run_model(
-                num_reqs, attn_metadata, num_tokens_across_dp
+                num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp
             )
             logits = self.model.compute_logits(last_hidden_states)
@@ -235,6 +241,7 @@ class EagleSpeculator:
         last_hidden_states, hidden_states = self.run_model(
             num_tokens,
             input_batch.attn_metadata,
+            input_batch.slot_mappings,
             num_tokens_across_dp=None,  # FIXME
         )
         sample_hidden_states = last_hidden_states[last_token_indices]
@@ -311,7 +318,12 @@ class EagleSpeculator:
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
         )
-        self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None)  # FIXME
+        slot_mappings_by_layer = build_slot_mappings_by_layer(
+            slot_mappings, self.kv_cache_config
+        )
+        self.generate_draft(
+            num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None
+        )  # FIXME
         return self.draft_tokens[:num_reqs]
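
Note that run_model now accepts attn_metadata and slot_mappings as Optional: the dummy run in the model runner (the @@ -345,6 +349,7 hunk above) passes None for both, while real draft generation passes the dicts cached on InputBatch.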

View File

@@ -69,7 +69,7 @@ class EagleCudaGraphManager:
         kv_cache_config: KVCacheConfig,
     ) -> None:
         num_reqs = min(num_tokens, self.max_num_reqs)
-        attn_metadata = prepare_inputs_to_capture(
+        attn_metadata, slot_mappings = prepare_inputs_to_capture(
             num_reqs,
             num_tokens,
             input_buffers,
@@ -81,13 +81,13 @@ class EagleCudaGraphManager:
         num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

         # Warm up.
-        generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
+        generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)

         # Capture the graph.
         assert num_tokens not in self.graphs
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, self.pool):
-            generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
+            generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
         self.graphs[num_tokens] = graph

     @torch.inference_mode()
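
The warm-up call and the captured call deliberately receive the same slot_mappings dict: CUDA graph replay reruns the recorded kernels on the exact tensor addresses seen at capture time, so any buffer the graph reads must exist before capture and be updated in place afterwards. A generic PyTorch illustration of that pattern (not vLLM code):

# Generic capture/replay sketch: warm up once, capture, then mutate the
# static buffers in place and replay.
import torch

static_input = torch.zeros(8, device="cuda")

def step() -> torch.Tensor:
    # Every tensor this touches must keep the same address across replays.
    return static_input * 2

step()  # warm-up: lazy allocations happen outside the graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = step()

static_input.copy_(torch.arange(8.0, device="cuda"))  # update in place
graph.replay()  # reruns the recorded kernels on the same buffers
print(static_output)  # reflects the new static_input values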