[Model Runner V2] Add probabilistic rejection sampling for spec decoding (#35461)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@@ -57,6 +57,10 @@ SpeculativeMethod = Literal[
     EagleModelTypes,
     NgramGPUTypes,
 ]
 
+RejectionSampleMethod = Literal[
+    "strict",
+    "probabilistic",
+]
 
 @config
@@ -171,6 +175,12 @@ class SpeculativeConfig:
     """Load config for the draft model. If not specified, will use the load
     config from the target model."""
 
+    rejection_sample_method: RejectionSampleMethod = "strict"
+    """Whether to use strict (target and draft sampled tokens match exactly)
+    or probabilistic rejection sampling. Both respect the target model
+    distribution, but the latter yields a higher acceptance rate at the cost
+    of more memory to cache draft logits."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
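The new option is read from `SpeculativeConfig` like any other field. A minimal sketch of opting in (only `rejection_sample_method` and its two literal values come from this change; the other arguments are illustrative):

```python
from vllm.config import SpeculativeConfig

spec_config = SpeculativeConfig(
    model="my-eagle3-draft",  # hypothetical draft model name
    num_speculative_tokens=3,  # illustrative
    rejection_sample_method="probabilistic",  # new field; defaults to "strict"
)
```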
@@ -90,7 +90,7 @@ from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
     set_eagle3_aux_hidden_state_layers,
 )
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
+from vllm.v1.worker.gpu.spec_decode.rejection_sampler import RejectionSampler
 from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.speculator = None
         self.num_speculative_steps = 0
         self.use_aux_hidden_state_outputs = False
+        use_strict_rejection_sampling = False
         if self.speculative_config is not None:
             self.num_speculative_steps = self.speculative_config.num_speculative_tokens
             if self.is_last_pp_rank:
@@ -172,6 +173,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     self.use_aux_hidden_state_outputs = True
                     if self.pp_size > 1:
                         raise ValueError("EAGLE3 with pipeline parallel is not supported.")
+            use_strict_rejection_sampling = (
+                self.speculative_config.rejection_sample_method == "strict"
+            )
 
         # Draft tokens propagation - for spec-dec + struct outputs.
         self.draft_tokens_handler = DraftTokensHandler(self.device)
@@ -183,6 +187,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_speculative_steps=self.num_speculative_steps,
             vocab_size=self.vocab_size,
             device=self.device,
+            model_dtype=self.dtype,
+            cache_draft_logits=not use_strict_rejection_sampling,
         )
         self.input_buffers = InputBuffers(
             max_num_reqs=self.max_num_reqs,
@@ -197,6 +203,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs_mode=self.model_config.logprobs_mode,
             num_speculative_tokens=self.num_speculative_steps + 1,
         )
+        self.rejection_sampler = RejectionSampler(
+            self.sampler,
+            num_speculative_steps=self.num_speculative_steps,
+            use_strict_rejection_sampling=use_strict_rejection_sampling,
+        )
         self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
 
         # CUDA graphs.
@@ -412,6 +423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             next_prefill_tokens=self.req_states.next_prefill_tokens,
             temperature=self.sampler.sampling_states.temperature.gpu,
             seeds=self.sampler.sampling_states.seeds.gpu,
+            draft_logits_out=self.req_states.draft_logits,
             num_tokens_across_dp=num_tokens_across_dp,
             dummy_run=True,
             skip_attn_for_dummy_run=skip_attn,
@@ -425,24 +437,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
         num_reqs = hidden_states.shape[0]
         logits = self.model.compute_logits(hidden_states)
-        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
-        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
-        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
-        dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
-        expanded_local_pos = torch.zeros(
-            num_reqs, dtype=torch.int32, device=self.device
+        dummy_input_batch = InputBatch.make_dummy(
+            num_reqs, num_reqs, self.input_buffers
         )
 
         # NOTE(woosuk): During the initial memory profiling, the sampler may skip
         # top_k, top_p, and logprobs, using less GPU memory than what is possible
         # during actual execution.
         self.sampler(
             logits,
-            idx_mapping,
-            idx_mapping_np,
-            idx_mapping_np,
-            pos,
-            dummy_input_ids,
-            expanded_local_pos,
+            dummy_input_batch,
         )
 
     @torch.inference_mode()
@@ -768,8 +772,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         grammar_output: GrammarOutput | None,
     ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
-        sample_pos = input_batch.positions[input_batch.logits_indices]
-        input_ids = input_batch.input_ids[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
             # Apply grammar bitmask to the logits in-place.
@@ -780,34 +782,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 grammar_output.grammar_bitmask,
             )
 
-        # Sample tokens and compute logprobs (if needed).
-        sampler_output = self.sampler(
-            logits,
-            input_batch.expanded_idx_mapping,
-            input_batch.idx_mapping_np,
-            input_batch.cu_num_logits_np,
-            sample_pos,
-            input_ids,
-            input_batch.expanded_local_pos,
-        )
-
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
-            num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
+            sampler_output = self.sampler(
+                logits,
+                input_batch,
+            )
         else:
             # Rejection sampling for spec decoding.
-            sampled_tokens, num_sampled = rejection_sample(
-                sampler_output.sampled_token_ids,
-                input_ids,
-                input_batch.cu_num_logits,
-                self.num_speculative_steps,
+            sampler_output = self.rejection_sampler(
+                logits,
+                input_batch,
+                # Draft logits are needed for probabilistic rejection sampling.
+                self.req_states.draft_logits[input_batch.idx_mapping]
+                if self.req_states.draft_logits is not None
+                else None,
             )
-            sampler_output.sampled_token_ids = sampled_tokens
 
         # Get the number of sampled and rejected tokens.
         # For chunked prefills, num_sampled and num_rejected are both 0.
         num_sampled, num_rejected = get_num_sampled_and_rejected(
-            num_sampled,
+            sampler_output.num_sampled,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
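One detail worth noting in the hunk above: `req_states.draft_logits` is a persistent per-slot buffer, while the rejection sampler consumes tensors in batch order, so the buffer is gathered with `input_batch.idx_mapping` before the call. A shape sketch (names from this diff, sizes illustrative):

```python
# req_states.draft_logits: [max_num_reqs, num_speculative_steps, vocab_size],
# indexed by persistent request slot.
# input_batch.idx_mapping: [num_reqs], the slot index of each batch entry.
# Advanced indexing reorders the cache to batch order, yielding a tensor of
# shape [num_reqs, num_speculative_steps, vocab_size].
draft_logits_for_batch = req_states.draft_logits[input_batch.idx_mapping]
```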
@@ -1105,6 +1100,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.req_states.next_prefill_tokens,
             self.sampler.sampling_states.temperature.gpu,
             self.sampler.sampling_states.seeds.gpu,
+            self.req_states.draft_logits,
             num_tokens_across_dp=num_tokens_across_dp,
         )
         self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
@@ -55,6 +55,8 @@ def _gumbel_sample_kernel(
     local_argmax_stride,
     local_max_ptr,
     local_max_stride,
+    processed_logits_ptr,
+    processed_logits_stride,
     logits_ptr,
     logits_stride,
     expanded_idx_mapping_ptr,
@@ -79,6 +81,20 @@ def _gumbel_sample_kernel(
     logits = logits.to(tl.float32)
 
     temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
+    if (temp != 0.0) and APPLY_TEMPERATURE:
+        # Apply temperature.
+        # NOTE(woosuk): Match the behavior of _temperature_kernel.
+        # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
+        logits = logits / temp
+
+    # Store the temperature-applied logits.
+    if processed_logits_ptr is not None:
+        tl.store(
+            processed_logits_ptr + req_state_idx * processed_logits_stride + block,
+            logits,
+            mask=mask,
+        )
+
     if temp != 0.0:
         # Calculate the seed for gumbel noise.
         seed = tl.load(seeds_ptr + req_state_idx)
@@ -90,12 +106,6 @@ def _gumbel_sample_kernel(
         u = tl.maximum(u, 1e-7)
         gumbel_noise = -tl.log(-tl.log(u))
 
-        # Apply temperature.
-        if APPLY_TEMPERATURE:
-            # NOTE(woosuk): Match the behavior of _temperature_kernel.
-            # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
-            logits = logits / temp
-
         # Apply gumbel noise.
         logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
 
@@ -112,6 +122,7 @@ def gumbel_sample(
     seed: torch.Tensor,  # [max_num_reqs]
     pos: torch.Tensor,  # [num_tokens]
     apply_temperature: bool,
+    processed_logits_out: torch.Tensor | None = None,  # [num_reqs, vocab_size]
 ) -> torch.Tensor:
     num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
@@ -133,6 +144,8 @@ def gumbel_sample(
         local_argmax.stride(0),
         local_max,
         local_max.stride(0),
+        processed_logits_out,
+        processed_logits_out.stride(0) if processed_logits_out is not None else 0,
         logits,
         logits.stride(0),
         expanded_idx_mapping,
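For background, `gumbel_sample` uses the Gumbel-max trick: adding i.i.d. Gumbel(0, 1) noise, computed as -log(-log(u)) for u ~ Uniform(0, 1), to the temperature-scaled logits and taking the argmax draws an exact sample from softmax(logits / temp). The change above moves the temperature division before the noise step so the logits written to `processed_logits_out` are noise-free: the draft-logit cache must hold the clean distribution, not one perturbed for a particular sample. A self-contained sketch of the trick (plain PyTorch, not the Triton kernel):

```python
import torch

def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    """Sample token ids from softmax(logits) via the Gumbel-max trick."""
    u = torch.rand_like(logits).clamp_min(1e-7)  # uniform noise, avoid log(0)
    gumbel = -torch.log(-torch.log(u))           # Gumbel(0, 1) noise
    return torch.argmax(logits + gumbel, dim=-1)
```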
@@ -12,3 +12,4 @@ class SamplerOutput:
     sampled_token_ids: torch.Tensor
     logprobs_tensors: LogprobsTensors | None
     num_nans: torch.Tensor | None
+    num_sampled: torch.Tensor | None
@@ -7,6 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.config.model import LogprobsMode
 from vllm.sampling_params import SamplingParams
+from vllm.v1.worker.gpu.input_batch import InputBatch
 from vllm.v1.worker.gpu.metrics.logits import get_num_nans
 from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -56,13 +57,15 @@ class Sampler:
     def __call__(
         self,
         logits: torch.Tensor,
-        expanded_idx_mapping: torch.Tensor,
-        idx_mapping_np: np.ndarray,
-        cu_num_logits_np: np.ndarray,
-        pos: torch.Tensor,
-        input_ids: torch.Tensor,
-        expanded_local_pos: torch.Tensor,
+        input_batch: InputBatch,
     ) -> SamplerOutput:
+        expanded_idx_mapping = input_batch.expanded_idx_mapping
+        idx_mapping_np = input_batch.idx_mapping_np
+        cu_num_logits_np = input_batch.cu_num_logits_np
+        expanded_local_pos = input_batch.expanded_local_pos
+        pos = input_batch.positions[input_batch.logits_indices]
+        input_ids = input_batch.input_ids[input_batch.logits_indices]
+
         # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
         # that num_nans is computed before applying penalties and temperature.
         num_nans = get_num_nans(logits) if self.compute_nans else None
@@ -95,10 +98,11 @@ class Sampler:
             sampled_token_ids=sampled.view(-1, 1),
             logprobs_tensors=logprobs_tensors,
             num_nans=num_nans,
+            num_sampled=input_batch.seq_lens.new_ones(input_batch.num_reqs),
         )
         return sampler_output
 
-    def sample(
+    def apply_sampling_params(
         self,
         logits: torch.Tensor,
         expanded_idx_mapping: torch.Tensor,
@@ -106,7 +110,7 @@ class Sampler:
         pos: torch.Tensor,
         input_ids: torch.Tensor,
         expanded_local_pos: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         # Copy logits to a new FP32 tensor.
         logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
 
@@ -143,13 +147,31 @@ class Sampler:
         self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
 
         # Apply top_k and/or top_p. This might or might not return a new tensor.
-        logits = self.sampling_states.apply_top_k_top_p(
+        return self.sampling_states.apply_top_k_top_p(
             logits, expanded_idx_mapping, idx_mapping_np
         )
 
+    def sample(
+        self,
+        logits: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
+        idx_mapping_np: np.ndarray,
+        pos: torch.Tensor,
+        input_ids: torch.Tensor,
+        expanded_local_pos: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        processed_logits = self.apply_sampling_params(
+            logits,
+            expanded_idx_mapping,
+            idx_mapping_np,
+            pos,
+            input_ids,
+            expanded_local_pos,
+        )
+
         # Sample the next token.
         sampled = gumbel_sample(
-            logits,
+            processed_logits,
             expanded_idx_mapping,
             self.sampling_states.temperature.gpu,
             self.sampling_states.seeds.gpu,
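The rename splits the old `Sampler.sample` into two stages so that `RejectionSampler` (below) can reuse the first stage on target logits without immediately drawing a token: `apply_sampling_params` applies penalties, min-p, and top-k/top-p and returns processed logits, and the new `sample` simply feeds them to `gumbel_sample`. A hedged usage sketch (argument names follow this diff; the input tensors are assumed to be prepared as in `GPUModelRunner`):

```python
# Stage 1: turn raw logits into processed logits (penalties, min-p, top-k/top-p).
processed_logits = sampler.apply_sampling_params(
    logits, expanded_idx_mapping, idx_mapping_np, pos, input_ids, expanded_local_pos
)
# Stage 2: Gumbel-max sampling over the processed logits; this is all that
# Sampler.sample adds on top of stage 1.
sampled = gumbel_sample(
    processed_logits,
    expanded_idx_mapping,
    sampler.sampling_states.temperature.gpu,
    sampler.sampling_states.seeds.gpu,
    pos,
    apply_temperature=True,
)
```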
@@ -140,6 +140,7 @@ class EagleSpeculator:
         slot_mappings: dict[str, torch.Tensor] | None,
         num_tokens_across_dp: torch.Tensor | None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        draft_logits_out: torch.Tensor | None = None,
     ) -> None:
         pos = self.input_buffers.positions[:num_reqs]
         query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -166,6 +167,9 @@ class EagleSpeculator:
                 self.seeds,
                 pos + 1,
                 apply_temperature=True,
+                processed_logits_out=draft_logits_out[:, step]
+                if draft_logits_out is not None
+                else None,
             )
             self.draft_tokens[:num_reqs, step] = draft_tokens
 
@@ -219,6 +223,8 @@ class EagleSpeculator:
         temperature: torch.Tensor,
         # [max_num_reqs]
         seeds: torch.Tensor,
+        # [max_num_reqs, num_speculative_steps, vocab_size]
+        draft_logits_out: torch.Tensor | None,
         num_tokens_across_dp: torch.Tensor | None = None,
         dummy_run: bool = False,
         skip_attn_for_dummy_run: bool = False,
@@ -271,6 +277,7 @@ class EagleSpeculator:
         idx_mapping.copy_(input_batch.idx_mapping)
         self.temperature.copy_(temperature)
         self.seeds.copy_(seeds)
+
         # Gather the values and copy them to the pre-allocated buffers.
         pos = self.input_buffers.positions[:num_reqs]
         torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
@@ -283,7 +290,11 @@ class EagleSpeculator:
             self.seeds,
             pos + 1,
             apply_temperature=True,
+            processed_logits_out=draft_logits_out[:, 0]
+            if draft_logits_out is not None
+            else None,
         )
 
         if self.num_speculative_steps == 1:
             # Early exit.
             return draft_tokens.view(-1, 1)
@@ -365,6 +376,7 @@ class EagleSpeculator:
             slot_mappings_updated,
             num_tokens_across_dp=num_tokens_across_dp,
             cudagraph_runtime_mode=batch_desc.cg_mode,
+            draft_logits_out=draft_logits_out,
         )
         return self.draft_tokens[:num_reqs]
 
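In the multi-step drafting loop above, each step's temperature-applied logits are written into `draft_logits_out[:, step]` as a side effect of `gumbel_sample`, and nothing is cached when the buffer is `None` (strict mode). A condensed sketch of that loop; names follow this diff, and the model call plus attention bookkeeping are elided behind a hypothetical `compute_draft_logits`:

```python
# Condensed sketch of EagleSpeculator's drafting loop with logit capture.
# draft_logits_out: [max_num_reqs, num_speculative_steps, vocab_size] or None.
for step in range(num_speculative_steps):
    logits = compute_draft_logits(step)  # hypothetical stand-in for the model call
    draft_tokens = gumbel_sample(
        logits,
        idx_mapping,
        temperature,
        seeds,
        pos + 1,
        apply_temperature=True,
        # Capture this step's processed logits only in probabilistic mode.
        processed_logits_out=draft_logits_out[:, step]
        if draft_logits_out is not None
        else None,
    )
    draft_tokens_buf[:num_reqs, step] = draft_tokens
```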
@@ -1,62 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import torch
-
-from vllm.triton_utils import tl, triton
-
-
-@triton.jit
-def _rejection_sample_kernel(
-    sampled_ptr,  # [num_reqs, num_speculative_steps + 1]
-    sampled_stride,
-    num_sampled_ptr,  # [num_reqs]
-    target_sampled_ptr,  # [num_draft_tokens + num_reqs]
-    input_ids_ptr,  # [num_draft_tokens + num_reqs]
-    cu_num_logits_ptr,  # [num_reqs + 1]
-):
-    req_idx = tl.program_id(0)
-    start_idx = tl.load(cu_num_logits_ptr + req_idx)
-    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
-    num_tokens = end_idx - start_idx
-
-    num_sampled = 0
-    rejected = False
-    for i in range(num_tokens - 1):
-        if not rejected:
-            target_sampled = tl.load(target_sampled_ptr + start_idx + i)
-            draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
-            tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
-            num_sampled += 1
-            if target_sampled != draft_sampled:
-                rejected = True
-    if not rejected:
-        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
-        tl.store(
-            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
-        )
-        num_sampled += 1
-    tl.store(num_sampled_ptr + req_idx, num_sampled)
-
-
-def rejection_sample(
-    # [num_draft_tokens + num_reqs]
-    target_sampled: torch.Tensor,
-    # [num_draft_tokens + num_reqs]
-    input_ids: torch.Tensor,
-    # [num_reqs + 1]
-    cu_num_logits: torch.Tensor,
-    num_speculative_steps: int,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    num_reqs = cu_num_logits.shape[0] - 1
-    sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
-    num_sampled = cu_num_logits.new_empty(num_reqs)
-    _rejection_sample_kernel[(num_reqs,)](
-        sampled,
-        sampled.stride(0),
-        num_sampled,
-        target_sampled,
-        input_ids,
-        cu_num_logits,
-        num_warps=1,
-    )
-    return sampled, num_sampled
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py (new file, 375 lines)
@@ -0,0 +1,375 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.metrics.logits import get_num_nans
+from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
+from vllm.v1.worker.gpu.sample.output import SamplerOutput
+from vllm.v1.worker.gpu.sample.sampler import Sampler
+
+
+@triton.jit
+def _strict_rejection_sample_kernel(
+    sampled_ptr,  # [num_reqs, num_speculative_steps + 1]
+    sampled_stride,
+    num_sampled_ptr,  # [num_reqs]
+    target_sampled_ptr,  # [num_draft_tokens + num_reqs]
+    input_ids_ptr,  # [num_draft_tokens + num_reqs]
+    cu_num_logits_ptr,  # [num_reqs + 1]
+):
+    req_idx = tl.program_id(0)
+    start_idx = tl.load(cu_num_logits_ptr + req_idx)
+    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
+    num_tokens = end_idx - start_idx
+
+    num_sampled = 0
+    rejected = False
+    for i in range(num_tokens - 1):
+        if not rejected:
+            target_sampled = tl.load(target_sampled_ptr + start_idx + i)
+            draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
+            tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
+            num_sampled += 1
+            if target_sampled != draft_sampled:
+                rejected = True
+    if not rejected:
+        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
+        tl.store(
+            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
+        )
+        num_sampled += 1
+    tl.store(num_sampled_ptr + req_idx, num_sampled)
+
+
+def strict_rejection_sample(
+    # [num_draft_tokens + num_reqs]
+    target_sampled: torch.Tensor,
+    # [num_draft_tokens + num_reqs]
+    draft_sampled: torch.Tensor,
+    # [num_reqs + 1]
+    cu_num_logits: torch.Tensor,
+    num_speculative_steps,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    num_reqs = cu_num_logits.shape[0] - 1
+    sampled = torch.empty(
+        num_reqs,
+        num_speculative_steps + 1,
+        dtype=target_sampled.dtype,
+        device=target_sampled.device,
+    )
+    num_sampled = torch.empty(
+        num_reqs,
+        dtype=torch.int32,
+        device=target_sampled.device,
+    )
+    _strict_rejection_sample_kernel[(num_reqs,)](
+        sampled,
+        sampled.stride(0),
+        num_sampled,
+        target_sampled,
+        draft_sampled,
+        cu_num_logits,
+        num_warps=1,
+    )
+    return sampled, num_sampled
+
+
+@triton.jit
+def _probabilistic_rejection_sample_kernel(
+    # [num_reqs, num_speculative_steps + 1]
+    sampled_ptr,
+    sampled_stride,
+    # [num_reqs]
+    rejected_steps_ptr,
+    # [num_logits]
+    draft_sampled_ptr,
+    # [num_logits, V]
+    target_probs_ptr,
+    target_probs_stride,
+    # [num_reqs, num_speculative_steps, V]
+    draft_probs_ptr,
+    draft_probs_stride_0,
+    draft_probs_stride_1,
+    # [num_reqs + 1]
+    cu_num_logits_ptr,
+    # [num_logits]
+    pos_ptr,
+    # [num_reqs]
+    idx_mapping_ptr,
+    # [num_reqs]
+    seeds_ptr,
+):
+    req_idx = tl.program_id(0)
+    start_idx = tl.load(cu_num_logits_ptr + req_idx)
+    num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
+    seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx))
+
+    rejected_step = 0
+    accepted = True
+    for i in range(num_tokens - 1):
+        if accepted:
+            draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1)
+            target_prob = tl.load(
+                target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled
+            )
+            draft_prob = tl.load(
+                draft_probs_ptr
+                + req_idx * draft_probs_stride_0
+                + i * draft_probs_stride_1
+                + draft_sampled
+            )
+            pos = tl.load(pos_ptr + start_idx + i)
+            u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
+            accepted &= target_prob > u * draft_prob
+            tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
+            rejected_step += accepted
+    tl.store(rejected_steps_ptr + req_idx, rejected_step)
+
+
+@triton.jit
+def _compute_residual_logits_kernel(
+    # [num_reqs, V]
+    residual_logits_ptr,
+    residual_logits_stride,
+    # [num_reqs]
+    residual_pos_ptr,
+    # [num_logits, V]
+    target_logits_ptr,
+    target_logits_stride,
+    # [num_logits, V]
+    target_probs_ptr,
+    target_probs_stride,
+    # [num_reqs, num_speculative_steps, V]
+    draft_probs_ptr,
+    draft_probs_stride_0,
+    draft_probs_stride_1,
+    # [num_reqs]
+    rejected_step_ptr,
+    # [num_reqs + 1]
+    cu_num_logits_ptr,
+    # [num_logits]
+    pos_ptr,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    block_idx = tl.program_id(1)
+
+    start_idx = tl.load(cu_num_logits_ptr + req_idx)
+    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
+    rejected_draft_step = tl.load(rejected_step_ptr + req_idx)
+    rejected_logit_idx = start_idx + rejected_draft_step
+
+    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = block_offsets < vocab_size
+
+    if rejected_logit_idx < end_idx - 1:
+        target_probs = tl.load(
+            target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
+            mask=mask,
+            other=0.0,
+        )
+        draft_probs = tl.load(
+            draft_probs_ptr
+            + req_idx * draft_probs_stride_0
+            + rejected_draft_step * draft_probs_stride_1
+            + block_offsets,
+            mask=mask,
+            other=0.0,
+        )
+        residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
+        residual_logits = tl.log(residual_probs)
+    else:
+        # This is a bonus token. Directly return the target logits.
+        residual_logits = tl.load(
+            target_logits_ptr
+            + rejected_logit_idx * target_logits_stride
+            + block_offsets,
+            mask=mask,
+            other=0.0,
+        )
+
+    tl.store(
+        residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
+        residual_logits,
+        mask=mask,
+    )
+
+    # First block computes the residual logit positions.
+    if block_idx == 0:
+        pos_val = tl.load(pos_ptr + rejected_logit_idx)
+        tl.store(residual_pos_ptr + req_idx, pos_val)
+
+
+def probabilistic_rejection_sample(
+    # [num_draft_tokens + num_reqs, V]
+    target_logits: torch.Tensor,
+    # [num_reqs, num_speculative_steps, V]
+    draft_logits: torch.Tensor,
+    # [num_draft_tokens + num_reqs]
+    draft_sampled: torch.Tensor,
+    # [num_reqs + 1]
+    cu_num_logits: torch.Tensor,
+    # [num_logits]
+    pos: torch.Tensor,
+    # [num_reqs]
+    idx_mapping: torch.Tensor,
+    temperature,
+    seeds,
+    num_speculative_steps,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    num_reqs = cu_num_logits.shape[0] - 1
+    device = target_logits.device
+    vocab_size = target_logits.shape[-1]
+
+    # Compute target and draft probs.
+    target_probs = torch.softmax(target_logits, dim=-1)
+    draft_probs = torch.softmax(draft_logits, dim=-1)
+
+    # Rejection sample.
+    # [num_reqs, num_speculative_steps + 1]
+    sampled = torch.empty(
+        num_reqs,
+        num_speculative_steps + 1,
+        dtype=torch.int64,
+        device=device,
+    )
+    # [num_reqs]
+    rejected_steps = torch.empty(
+        num_reqs,
+        dtype=torch.int64,
+        device=device,
+    )
+    _probabilistic_rejection_sample_kernel[(num_reqs,)](
+        sampled,
+        sampled.stride(0),
+        rejected_steps,
+        draft_sampled,
+        target_probs,
+        target_probs.stride(0),
+        draft_probs,
+        draft_probs.stride(0),
+        draft_probs.stride(1),
+        cu_num_logits,
+        pos,
+        idx_mapping,
+        seeds,
+        num_warps=1,
+    )
+
+    # Compute the logits and positions to resample the rejected/bonus
+    # tokens from.
+    # [num_reqs, vocab_size]
+    residual_logits = torch.empty(
+        num_reqs,
+        vocab_size,
+        dtype=target_logits.dtype,
+        device=device,
+    )
+    # [num_reqs]
+    residual_pos = torch.empty(
+        num_reqs,
+        dtype=pos.dtype,
+        device=device,
+    )
+    BLOCK_SIZE = 1024
+    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+    _compute_residual_logits_kernel[(num_reqs, num_blocks)](
+        residual_logits,
+        residual_logits.stride(0),
+        residual_pos,
+        target_logits,
+        target_logits.stride(0),
+        target_probs,
+        target_probs.stride(0),
+        draft_probs,
+        draft_probs.stride(0),
+        draft_probs.stride(1),
+        rejected_steps,
+        cu_num_logits,
+        pos,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    # Gumbel sample tokens from the residual distribution.
+    resampled = gumbel_sample(
+        residual_logits,
+        idx_mapping,
+        temperature,
+        seeds,
+        residual_pos,
+        apply_temperature=False,
+    )
+    sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
+
+    return sampled, rejected_steps + 1
+
+
+class RejectionSampler:
+    def __init__(
+        self,
+        sampler: Sampler,
+        num_speculative_steps,
+        use_strict_rejection_sampling: bool = True,
+    ):
+        self.sampler = sampler
+        self.num_speculative_steps = num_speculative_steps
+        self.use_strict_rejection_sampling = use_strict_rejection_sampling
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        input_batch: InputBatch,
+        draft_logits: torch.Tensor | None = None,
+    ) -> SamplerOutput:
+        draft_sampled = input_batch.input_ids[input_batch.logits_indices]
+        # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
+        # that num_nans is computed before applying penalties and temperature.
+        num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
+
+        if self.use_strict_rejection_sampling:
+            sampler_output = self.sampler(
+                logits,
+                input_batch,
+            )
+            logprobs_tensors = sampler_output.logprobs_tensors
+            sampled, num_sampled = strict_rejection_sample(
+                sampler_output.sampled_token_ids.view(-1),
+                draft_sampled,
+                input_batch.cu_num_logits,
+                self.num_speculative_steps,
+            )
+        else:
+            assert draft_logits is not None
+            pos = input_batch.positions[input_batch.logits_indices]
+            processed_logits = self.sampler.apply_sampling_params(
+                logits,
+                input_batch.expanded_idx_mapping,
+                input_batch.idx_mapping_np,
+                pos,
+                draft_sampled,
+                input_batch.expanded_local_pos,
+            )
+            # TODO (TheEpicDolphin): Return logprobs for sampled token ids.
+            logprobs_tensors = None
+            sampled, num_sampled = probabilistic_rejection_sample(
+                processed_logits,
+                draft_logits,
+                draft_sampled,
+                input_batch.cu_num_logits,
+                pos,
+                input_batch.idx_mapping,
+                self.sampler.sampling_states.temperature.gpu,
+                self.sampler.sampling_states.seeds.gpu,
+                self.num_speculative_steps,
+            )
+
+        return SamplerOutput(
+            sampled_token_ids=sampled,
+            logprobs_tensors=logprobs_tensors,
+            num_nans=num_nans,
+            num_sampled=num_sampled,
+        )
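For reference, the two kernels above implement the standard speculative-sampling test: the draft token x proposed at step i is accepted with probability min(1, p(x)/q(x)), where p is the target distribution and q the draft distribution (the kernel's `accepted &= target_prob > u * draft_prob` with u ~ Uniform(0, 1) is exactly this test), and at the first rejection a replacement token is drawn from the residual distribution proportional to max(p - q, 0). This keeps the output distributed exactly as the target model while accepting more often than exact token matching; if every draft token is accepted, the bonus token is drawn from the target logits directly (the `else` branch of `_compute_residual_logits_kernel`). A compact eager-mode reference for a single request, as a sketch of the assumed semantics rather than the kernel itself:

```python
import torch

def reference_rejection_sample(
    target_probs: torch.Tensor,  # [k, V] target distribution at each draft step
    draft_probs: torch.Tensor,   # [k, V] draft distribution at each draft step
    draft_tokens: torch.Tensor,  # [k] tokens the draft model proposed
) -> list[int]:
    """Accept-or-resample for one request; output matches the target distribution."""
    out: list[int] = []
    for i, tok in enumerate(draft_tokens.tolist()):
        u = torch.rand(())
        if target_probs[i, tok] > u * draft_probs[i, tok]:  # accept w.p. min(1, p/q)
            out.append(tok)
        else:
            residual = (target_probs[i] - draft_probs[i]).clamp_min(0.0)
            residual /= residual.sum()
            out.append(int(torch.multinomial(residual, 1)))
            return out  # stop at the first rejection
    return out  # all accepted; the caller samples the bonus token from the target
```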
@@ -15,6 +15,8 @@ class RequestState:
         num_speculative_steps: int,
         vocab_size: int,
         device: torch.device,
+        model_dtype: torch.dtype,
+        cache_draft_logits: bool,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -70,6 +72,19 @@ class RequestState:
             dtype=torch.int64,
             device=device,
         )
+        # Draft token logits.
+        # NOTE: This tensor maintains the "processed" logits after applying temperature,
+        # top-p, etc.
+        self.draft_logits: torch.Tensor | None = None
+        if cache_draft_logits:
+            self.draft_logits = torch.zeros(
+                self.max_num_reqs,
+                self.num_speculative_steps,
+                self.vocab_size,
+                dtype=model_dtype,
+                device=device,
+            )
+
         self.next_prefill_tokens = torch.zeros(
             self.max_num_reqs, dtype=torch.int32, device=device
         )
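The memory cost called out in the `SpeculativeConfig` docstring is exactly this buffer: one processed draft distribution per speculative step for every request slot, kept in the model dtype. As a rough, illustrative calculation (the numbers are assumptions, not from this PR): with max_num_reqs = 128, num_speculative_steps = 3, a 128,000-entry vocabulary, and bf16 activations, the cache occupies 128 * 3 * 128,000 * 2 bytes, about 94 MiB of GPU memory. In strict mode (cache_draft_logits=False) no buffer is allocated at all.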