[Model Runner V2] Minor optimization for eagle input processing (#32535)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_rejected: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.speculator is not None
|
||||
last_sampled_tokens = self.req_states.last_sampled_tokens[
|
||||
input_batch.idx_mapping
|
||||
]
|
||||
next_prefill_tokens = self.req_states.next_prefill_tokens[
|
||||
input_batch.idx_mapping
|
||||
]
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
last_hidden_states,
|
||||
aux_hidden_states,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
last_sampled_tokens,
|
||||
next_prefill_tokens,
|
||||
self.req_states.last_sampled_tokens,
|
||||
self.req_states.next_prefill_tokens,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
)
|
||||
|
||||
@@ -195,9 +195,9 @@ class EagleSpeculator:
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_rejected: torch.Tensor,
|
||||
# [num_reqs]
|
||||
# [max_num_reqs]
|
||||
last_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
# [max_num_reqs]
|
||||
next_prefill_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
temperature: torch.Tensor,
|
||||
@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
|
||||
eagle_positions_ptr,
|
||||
target_input_ids_ptr,
|
||||
target_positions_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_sampled_ptr,
|
||||
next_prefill_tokens_ptr,
|
||||
num_sampled_ptr,
|
||||
@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel(
|
||||
|
||||
num_sampled = tl.load(num_sampled_ptr + batch_idx)
|
||||
if num_sampled > 0:
|
||||
next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
|
||||
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
|
||||
else:
|
||||
# Chunked prefilling.
|
||||
# Get the next prefill token.
|
||||
next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
|
||||
next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)
|
||||
|
||||
# Shift target_input_ids by one.
|
||||
for i in range(1, query_len, BLOCK_SIZE):
|
||||
@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_rejected: torch.Tensor,
|
||||
# [num_reqs]
|
||||
# [max_num_reqs]
|
||||
last_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
# [max_num_reqs]
|
||||
next_prefill_tokens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = input_batch.num_reqs
|
||||
@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
|
||||
input_buffers.positions,
|
||||
input_batch.input_ids,
|
||||
input_batch.positions,
|
||||
input_batch.idx_mapping,
|
||||
last_sampled,
|
||||
next_prefill_tokens,
|
||||
num_sampled,
|
||||
|
||||
Reference in New Issue
Block a user