[Model Runner V2] Minor optimization for eagle input processing (#32535)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-17 21:55:17 -08:00
committed by GitHub
parent 8cc26acd8b
commit 963dc0b865
2 changed files with 12 additions and 14 deletions

View File

@@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_rejected: torch.Tensor,
) -> torch.Tensor:
assert self.speculator is not None
- last_sampled_tokens = self.req_states.last_sampled_tokens[
- input_batch.idx_mapping
- ]
- next_prefill_tokens = self.req_states.next_prefill_tokens[
- input_batch.idx_mapping
- ]
draft_tokens = self.speculator.propose(
input_batch,
last_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
- last_sampled_tokens,
- next_prefill_tokens,
+ self.req_states.last_sampled_tokens,
+ self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
)

View File

@@ -195,9 +195,9 @@ class EagleSpeculator:
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
- # [num_reqs]
+ # [max_num_reqs]
last_sampled: torch.Tensor,
- # [num_reqs]
+ # [max_num_reqs]
next_prefill_tokens: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr,
target_input_ids_ptr,
target_positions_ptr,
+ idx_mapping_ptr,
last_sampled_ptr,
next_prefill_tokens_ptr,
num_sampled_ptr,
@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
+ req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel(
num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0:
- next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
+ next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else:
# Chunked prefilling.
# Get the next prefill token.
- next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
+ next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)
# Shift target_input_ids by one.
for i in range(1, query_len, BLOCK_SIZE):
@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
- # [num_reqs]
+ # [max_num_reqs]
last_sampled: torch.Tensor,
- # [num_reqs]
+ # [max_num_reqs]
next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
num_reqs = input_batch.num_reqs
@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
input_buffers.positions,
input_batch.input_ids,
input_batch.positions,
+ input_batch.idx_mapping,
last_sampled,
next_prefill_tokens,
num_sampled,