diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 6333075ed..63635640b 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_rejected: torch.Tensor, ) -> torch.Tensor: assert self.speculator is not None - last_sampled_tokens = self.req_states.last_sampled_tokens[ - input_batch.idx_mapping - ] - next_prefill_tokens = self.req_states.next_prefill_tokens[ - input_batch.idx_mapping - ] draft_tokens = self.speculator.propose( input_batch, last_hidden_states, aux_hidden_states, num_sampled, num_rejected, - last_sampled_tokens, - next_prefill_tokens, + self.req_states.last_sampled_tokens, + self.req_states.next_prefill_tokens, self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, ) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index a208c4105..b4d1964f9 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -195,9 +195,9 @@ class EagleSpeculator: num_sampled: torch.Tensor, # [num_reqs] num_rejected: torch.Tensor, - # [num_reqs] + # [max_num_reqs] last_sampled: torch.Tensor, - # [num_reqs] + # [max_num_reqs] next_prefill_tokens: torch.Tensor, # [max_num_reqs] temperature: torch.Tensor, @@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel( eagle_positions_ptr, target_input_ids_ptr, target_positions_ptr, + idx_mapping_ptr, last_sampled_ptr, next_prefill_tokens_ptr, num_sampled_ptr, @@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel( BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + query_start = tl.load(query_start_loc_ptr + batch_idx) query_end = tl.load(query_start_loc_ptr + batch_idx + 1) query_len = query_end - query_start @@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel( num_sampled = tl.load(num_sampled_ptr + batch_idx) if num_sampled > 0: - next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32) + next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32) else: # Chunked prefilling. # Get the next prefill token. - next_token = tl.load(next_prefill_tokens_ptr + batch_idx) + next_token = tl.load(next_prefill_tokens_ptr + req_state_idx) # Shift target_input_ids by one. for i in range(1, query_len, BLOCK_SIZE): @@ -370,9 +373,9 @@ def prepare_eagle_inputs( num_sampled: torch.Tensor, # [num_reqs] num_rejected: torch.Tensor, - # [num_reqs] + # [max_num_reqs] last_sampled: torch.Tensor, - # [num_reqs] + # [max_num_reqs] next_prefill_tokens: torch.Tensor, ) -> torch.Tensor: num_reqs = input_batch.num_reqs @@ -387,6 +390,7 @@ def prepare_eagle_inputs( input_buffers.positions, input_batch.input_ids, input_batch.positions, + input_batch.idx_mapping, last_sampled, next_prefill_tokens, num_sampled,