Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
@@ -101,9 +101,6 @@ class CudaGraphManager:
             kv_cache_config,
         )
         num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
-        slot_mappings_by_layer = build_slot_mappings_by_layer(
-            slot_mappings, kv_cache_config
-        )

         # Warm up.
         with set_forward_context(
@@ -112,7 +109,7 @@ class CudaGraphManager:
             num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
-            slot_mapping=slot_mappings_by_layer,
+            slot_mapping=slot_mappings,
         ):
             hidden_states = model(
                 input_ids=input_ids,
@@ -132,7 +129,7 @@ class CudaGraphManager:
                 num_tokens=num_tokens,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
-                slot_mapping=slot_mappings_by_layer,
+                slot_mapping=slot_mappings,
             ),
             torch.cuda.graph(graph, self.pool),
         ):
@@ -252,7 +249,7 @@ def prepare_inputs_to_capture(
     attn_metadata_builders: list[AttentionMetadataBuilder],
     max_model_len: int,
     kv_cache_config: KVCacheConfig,
-) -> tuple[dict[str, Any], torch.Tensor]:
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
     num_tokens_per_req = num_tokens // num_reqs

     query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
@@ -269,6 +266,9 @@ def prepare_inputs_to_capture(

     input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
     slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+    slot_mappings_by_layer = build_slot_mappings_by_layer(
+        slot_mappings, kv_cache_config
+    )

     attn_metadata = build_attn_metadata(
         attn_metadata_builders=attn_metadata_builders,
@@ -282,4 +282,4 @@ def prepare_inputs_to_capture(
         slot_mappings=slot_mappings,
         kv_cache_config=kv_cache_config,
     )
-    return attn_metadata, slot_mappings
+    return attn_metadata, slot_mappings_by_layer

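Note: the body of build_slot_mappings_by_layer is not part of this diff. A minimal sketch of what such a helper could look like, assuming slot_mappings is the [num_kv_cache_groups, num_tokens] tensor sliced above and that KVCacheConfig exposes kv_cache_groups, each with a layer_names list (attribute names assumed, not confirmed by this commit):

    import torch

    def build_slot_mappings_by_layer(
        slot_mappings: torch.Tensor,  # [num_kv_cache_groups, num_tokens]
        kv_cache_config,              # KVCacheConfig (assumed attributes below)
    ) -> dict[str, torch.Tensor]:
        slot_mappings_by_layer: dict[str, torch.Tensor] = {}
        for group_idx, group in enumerate(kv_cache_config.kv_cache_groups):
            # All layers in one KV cache group share a single slot-mapping row.
            for layer_name in group.layer_names:
                slot_mappings_by_layer[layer_name] = slot_mappings[group_idx]
        return slot_mappings_by_layer

In this sketch each per-layer entry is a view of its group's row, not a copy, so expanding to a per-layer dict adds no per-token memory.
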
@@ -66,6 +66,8 @@ class InputBatch:

     # layer_name -> Metadata
     attn_metadata: dict[str, Any]
+    # layer_name -> slot_mapping
+    slot_mappings: dict[str, torch.Tensor]

     # [total_num_logits]
     logits_indices: torch.Tensor
@@ -133,6 +135,7 @@ class InputBatch:
             mrope_positions=None,
             inputs_embeds=None,
             attn_metadata=None,  # type: ignore
+            slot_mappings=None,  # type: ignore
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,

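Like attn_metadata, the new slot_mappings field is declared non-optional but constructed with a None placeholder and filled in later by the runner, hence the `# type: ignore`. A toy illustration of that pattern (hypothetical names, not vLLM code):

    from dataclasses import dataclass
    import torch

    @dataclass
    class Batch:
        # layer_name -> slot_mapping; populated after construction.
        slot_mappings: dict[str, torch.Tensor]

    batch = Batch(slot_mappings=None)  # type: ignore[arg-type]
    batch.slot_mappings = {"layers.0.attn": torch.zeros(8, dtype=torch.int64)}
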
@@ -269,6 +269,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         slot_mappings = self.block_tables.get_dummy_slot_mappings(
             input_batch.num_tokens
         )
+        slot_mappings_by_layer = build_slot_mappings_by_layer(
+            slot_mappings, self.kv_cache_config
+        )
         attn_metadata = build_attn_metadata(
             attn_metadata_builders=self.attn_metadata_builders,
             num_reqs=input_batch.num_reqs,
@@ -282,6 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             kv_cache_config=self.kv_cache_config,
         )
         input_batch.attn_metadata = attn_metadata
+        input_batch.slot_mappings = slot_mappings_by_layer

     @torch.inference_mode()
     def _dummy_run(
@@ -345,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.speculator.run_model(
             self.max_num_tokens,
             attn_metadata=None,
+            slot_mappings=None,
             num_tokens_across_dp=num_tokens_across_dp,
         )
         torch.cuda.synchronize()
@@ -615,6 +620,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             query_start_loc,
             self.input_buffers.positions[:num_tokens],
         )
+        # Layer name -> slot mapping.
+        slot_mappings_by_layer = build_slot_mappings_by_layer(
+            slot_mappings, self.kv_cache_config
+        )

         # Layer name -> attention metadata.
         attn_metadata = build_attn_metadata(
@@ -655,6 +664,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             mrope_positions=mrope_positions,
             inputs_embeds=None,
             attn_metadata=attn_metadata,
+            slot_mappings=slot_mappings_by_layer,
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
@@ -882,14 +892,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.uses_mrope:
             assert input_batch.mrope_positions is not None
             positions = input_batch.mrope_positions
-        slot_mappings = self.block_tables.compute_slot_mappings(
-            input_batch.idx_mapping,
-            input_batch.query_start_loc,
-            input_batch.positions[: input_batch.num_tokens],
-        )
-        slot_mappings_by_layer = build_slot_mappings_by_layer(
-            slot_mappings, self.kv_cache_config
-        )
         with set_forward_context(
             input_batch.attn_metadata,
             self.vllm_config,
@@ -897,7 +899,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # TODO(woosuk): Support piecewise CUDA graph.
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
-            slot_mapping=slot_mappings_by_layer,
+            slot_mapping=input_batch.slot_mappings,
         ):
             self.kv_connector.pre_forward(scheduler_output)
             hidden_states = self.model(

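Net effect in the runner: the per-layer dict is built once, where the batch's attention metadata is built, and the execute path reads input_batch.slot_mappings instead of recomputing slot mappings on every forward call. A condensed toy of the prepare-once/reuse pattern (hypothetical names, not vLLM code):

    import torch
    from dataclasses import dataclass

    @dataclass
    class Batch:
        num_tokens: int
        slot_mappings: dict[str, torch.Tensor] | None = None

    def prepare(batch: Batch) -> None:
        # Built once per batch, alongside attention metadata.
        flat = torch.arange(batch.num_tokens, dtype=torch.int64)
        batch.slot_mappings = {"layers.0.attn": flat}

    def forward(batch: Batch) -> None:
        # The forward path only reads the cached dict.
        assert batch.slot_mappings is not None
        for name, slots in batch.slot_mappings.items():
            _ = (name, slots.shape)

    batch = Batch(num_tokens=8)
    prepare(batch)
    forward(batch)
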
@@ -13,7 +13,10 @@ from vllm.model_executor.model_loader import get_model
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
+from vllm.v1.worker.gpu.attn_utils import (
+    build_attn_metadata,
+    build_slot_mappings_by_layer,
+)
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -108,7 +111,8 @@ class EagleSpeculator:
     def run_model(
         self,
         num_tokens: int,
-        attn_metadata: dict[str, Any],
+        attn_metadata: dict[str, Any] | None,
+        slot_mappings: dict[str, torch.Tensor] | None,
         num_tokens_across_dp: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         with set_forward_context(
@@ -117,6 +121,7 @@ class EagleSpeculator:
             num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
+            slot_mapping=slot_mappings,
         ):
             ret_hidden_states = self.model(
                 input_ids=self.input_buffers.input_ids[:num_tokens],
@@ -134,6 +139,7 @@ class EagleSpeculator:
         self,
         num_reqs: int,
         attn_metadata: dict[str, Any],
+        slot_mappings: dict[str, torch.Tensor],
         num_tokens_across_dp: torch.Tensor | None,
     ) -> None:
         pos = self.input_buffers.positions[:num_reqs]
@@ -142,7 +148,7 @@ class EagleSpeculator:
         for step in range(1, self.num_speculative_steps):
             # Run the eagle model.
             last_hidden_states, hidden_states = self.run_model(
-                num_reqs, attn_metadata, num_tokens_across_dp
+                num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp
             )
             logits = self.model.compute_logits(last_hidden_states)

@@ -235,6 +241,7 @@ class EagleSpeculator:
         last_hidden_states, hidden_states = self.run_model(
             num_tokens,
             input_batch.attn_metadata,
+            input_batch.slot_mappings,
             num_tokens_across_dp=None,  # FIXME
         )
         sample_hidden_states = last_hidden_states[last_token_indices]
@@ -311,7 +318,12 @@ class EagleSpeculator:
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
         )
-        self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None)  # FIXME
+        slot_mappings_by_layer = build_slot_mappings_by_layer(
+            slot_mappings, self.kv_cache_config
+        )
+        self.generate_draft(
+            num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None
+        )  # FIXME
         return self.draft_tokens[:num_reqs]

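In run_model, both attn_metadata and slot_mappings are now Optional: dummy/warm-up runs (see the _dummy_run hunk above) pass None for both, while real draft steps pass the per-layer dicts from the batch. A toy of the two call shapes (hypothetical, not vLLM code):

    from typing import Any
    import torch

    def run_model(
        num_tokens: int,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
    ) -> str:
        if attn_metadata is None and slot_mappings is None:
            return f"dummy run over {num_tokens} tokens"
        assert attn_metadata is not None and slot_mappings is not None
        return f"draft step over {num_tokens} tokens, {len(slot_mappings)} layers"

    print(run_model(16, None, None))                          # warm-up path
    print(run_model(16, {}, {"l0": torch.zeros(16).long()}))  # real step
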
@@ -69,7 +69,7 @@ class EagleCudaGraphManager:
         kv_cache_config: KVCacheConfig,
     ) -> None:
         num_reqs = min(num_tokens, self.max_num_reqs)
-        attn_metadata = prepare_inputs_to_capture(
+        attn_metadata, slot_mappings = prepare_inputs_to_capture(
             num_reqs,
             num_tokens,
             input_buffers,
@@ -81,13 +81,13 @@ class EagleCudaGraphManager:
         num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

         # Warm up.
-        generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
+        generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)

         # Capture the graph.
         assert num_tokens not in self.graphs
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, self.pool):
-            generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
+            generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
         self.graphs[num_tokens] = graph

     @torch.inference_mode()
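The capture sequence above follows the standard warm-up-then-record recipe: run the function once eagerly (so lazy allocations and kernel autotuning happen outside the graph), then record it under torch.cuda.graph and replay later. A standalone toy of that recipe (not vLLM code):

    import torch

    def capture(fn, pool=None) -> torch.cuda.CUDAGraph:
        fn()                       # warm up: lazy init happens outside the graph
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            fn()                   # kernels launched here are recorded, not run
        return graph               # later: graph.replay()

    if torch.cuda.is_available():
        x = torch.zeros(1024, device="cuda")
        g = capture(lambda: x.add_(1))
        g.replay()                 # re-executes the recorded add_ on the same buffer

Passing a shared pool, as the manager does with self.pool, lets multiple graphs reuse one memory pool instead of each reserving its own.
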