[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels (#14930)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
+from vllm.v1.sample.rejection_sampler import RejectionSampler
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
 from vllm.v1.utils import bind_kv_cache
@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_spec_decode = False
         if self.speculative_config:
             self.use_spec_decode = True
-            self.rejection_sampler = RejectionSampler()
             # TODO: find a better way to check if we are using ngram.
             assert self.speculative_config.ngram_prompt_lookup_min, \
                 "Currently, only ngram spec decode is supported in V1."
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.speculative_config.ngram_prompt_lookup_min,
                 self.speculative_config.num_speculative_tokens,
             )
+            self.rejection_sampler = RejectionSampler()
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor]:
+    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
+               Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
-        if use_spec_decode:
-            logits_indices = self._calc_spec_decode_metadata(
-                scheduler_output, cu_num_tokens)
-        else:
+        if not use_spec_decode:
             # NOTE(woosuk): Due to chunked prefills, the batch may contain
             # partial requests. While we should not sample any token
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             logits_indices = attn_metadata.query_start_loc[1:] - 1
+            spec_decode_metadata = None
+        else:
+            # Get the number of draft tokens for each request.
+            # Iterate over the dictionary rather than all requests since not all
+            # requests have draft tokens.
+            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            for req_id, draft_token_ids in (
+                    scheduler_output.scheduled_spec_decode_tokens.items()):
+                req_idx = self.input_batch.req_id_to_index[req_id]
+                num_draft_tokens[req_idx] = len(draft_token_ids)
+
+            spec_decode_metadata = self._calc_spec_decode_metadata(
+                num_draft_tokens, cu_num_tokens)
+            logits_indices = spec_decode_metadata.logits_indices
 
         # Hot-Swap lora model
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return attn_metadata, logits_indices
+        return attn_metadata, logits_indices, spec_decode_metadata
 
     def _compute_cascade_attn_prefix_len(
         self,
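Note: the non-spec-decode path above relies on attn_metadata.query_start_loc being the exclusive cumulative sum of each request's scheduled tokens, so query_start_loc[1:] - 1 lands on the last scheduled token of every request. A minimal sketch of that indexing, using the same example batch as the comments in _calc_spec_decode_metadata below (4, 100, 3, 100, 2 scheduled tokens):

    import torch

    # Exclusive cumulative sum over [4, 100, 3, 100, 2] scheduled tokens.
    query_start_loc = torch.tensor([0, 4, 104, 107, 207, 209])

    # The last token of request i sits at query_start_loc[i + 1] - 1.
    logits_indices = query_start_loc[1:] - 1
    print(logits_indices)  # tensor([  3, 103, 106, 206, 208])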
@@ -732,49 +745,78 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def _calc_spec_decode_metadata(
         self,
-        scheduler_output: "SchedulerOutput",
-        cu_num_tokens: np.ndarray,
-    ) -> torch.Tensor:
-        # Get the number of spec decode tokens for each request.
-        num_reqs = self.input_batch.num_reqs
-        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_spec_decode_tokens[i] = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+        num_draft_tokens: np.ndarray,
+        cu_num_scheduled_tokens: np.ndarray,
+    ) -> SpecDecodeMetadata:
+        # Inputs:
+        # cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
+        # num_draft_tokens:         [  3,   0,   2,   0,   1]
+        # Outputs:
+        # cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
+        # logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
+        #                            206, 207, 208]
+        # target_logits_indices:    [  0,   1,   2,   5,   6,   9]
+        # bonus_logits_indices:     [  3,   4,   7,   8,  10]
 
-        # Get spec decode logits indices.
-        # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
-        # cu_num_tokens: [4, 104, 107, 207, 209]
-        # num_spec_tokens_list: [3, 0, 2, 0, 1]
-        # num_sampled_tokens: [4, 1, 3, 1, 2]
-        # spec_decode_logits_indices:
-        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-        num_sampled_tokens = num_spec_decode_tokens + 1
-        # logits_start_loc: [0, 103, 104, 206, 207]
-        logits_start_loc = cu_num_tokens - num_sampled_tokens
-        # [0, 103, 104, 206, 207] ->
-        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
-        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
-        # The following three lines:
-        # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
-        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
-        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
-        # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-        cumsums_sampled_offsets = np.repeat(
-            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
-        # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-        # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-        total_num_sampled_tokens = num_sampled_tokens.sum()
-        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
-                          cumsums_sampled_offsets)
+        # Compute the logits indices.
+        # [4, 1, 3, 1, 2]
+        num_sampled_tokens = num_draft_tokens + 1
+        # Step 1. [4, 5, 8, 9, 11]
+        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
+        total_num_sampled_tokens = cu_num_sampled_tokens[-1]
+        # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
+                                    num_sampled_tokens)
+        # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
+        # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
+        logits_indices = np.repeat(
+            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
+        # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        logits_indices += arange
 
-        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
-        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-        spec_decode_logits_indices = logits_start_loc + sampled_arange
-        return torch.from_numpy(spec_decode_logits_indices).to(
+        # Compute the bonus logits indices.
+        bonus_logits_indices = cu_num_sampled_tokens - 1
+
+        # Compute the draft logits indices.
+        # [3, 3, 5, 5, 6]
+        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
+        total_num_draft_tokens = cu_num_draft_tokens[-1]
+        # [0, 0, 0, 3, 3, 5]
+        cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
+                                    num_draft_tokens)
+        # [0, 1, 2, 0, 1, 0]
+        arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
+        # [0, 0, 0, 5, 5, 9]
+        target_logits_indices = np.repeat(
+            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
+        # [0, 1, 2, 5, 6, 9]
+        target_logits_indices += arange
+
+        # TODO: Optimize the CPU -> GPU copy.
+        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
             self.device, non_blocking=True)
+        logits_indices = torch.from_numpy(logits_indices).to(self.device,
+                                                             non_blocking=True)
+        target_logits_indices = torch.from_numpy(target_logits_indices).to(
+            self.device, non_blocking=True)
+        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
+            self.device, non_blocking=True)
+
+        # Compute the draft token ids.
+        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
+        draft_token_ids = self.input_ids[logits_indices]
+        draft_token_ids = draft_token_ids[target_logits_indices + 1]
+
+        metadata = SpecDecodeMetadata(
+            draft_token_ids=draft_token_ids,
+            num_draft_tokens=num_draft_tokens.tolist(),
+            cu_num_draft_tokens=cu_num_draft_tokens,
+            target_logits_indices=target_logits_indices,
+            bonus_logits_indices=bonus_logits_indices,
+            logits_indices=logits_indices,
+        )
+        return metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
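Note: the new _calc_spec_decode_metadata builds all four index arrays with the same "segmented arange" trick: repeat each segment's start offset, then subtract it from a flat arange. A standalone NumPy sketch of that arithmetic, using the example values from the docstring above (names mirror the method; self.arange_np is replaced by a plain np.arange, so this is an illustration, not the vLLM implementation itself):

    import numpy as np

    cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)
    num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

    num_sampled_tokens = num_draft_tokens + 1          # [4, 1, 3, 1, 2]
    cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
    total = cu_num_sampled_tokens[-1]                  # 11

    # Per-segment arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
    offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                        num_sampled_tokens)
    arange = np.arange(total, dtype=np.int32) - offsets

    # Each request samples its last num_draft_tokens + 1 positions.
    logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                               num_sampled_tokens) + arange
    print(logits_indices)         # [0 1 2 3 103 104 105 106 206 207 208]

    # Last sampled position of each request holds the bonus token.
    bonus_logits_indices = cu_num_sampled_tokens - 1   # [3, 4, 7, 8, 10]

    # Same trick over the draft-token segments.
    cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
    draft_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
                              num_draft_tokens)
    draft_arange = np.arange(cu_num_draft_tokens[-1],
                             dtype=np.int32) - draft_offsets
    target_logits_indices = np.repeat(
        cu_num_sampled_tokens - num_sampled_tokens,
        num_draft_tokens) + draft_arange
    print(target_logits_indices)  # [0 1 2 5 6 9]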
@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             encoder_outputs = []
 
         # Prepare the decoder inputs.
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+        attn_metadata, logits_indices, spec_decode_metadata = (
+            self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
-        if not self.use_spec_decode:
+        if spec_decode_metadata is None:
             sampler_output = self.model.sample(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
         else:
-            draft_token_ids = [
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
-                for req_id in self.input_batch.req_ids
-            ]
-            sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
-            recover_logits_idx = np.cumsum(sample_lens) - 1
-            target_probs = self.rejection_sampler.compute_probs(
-                logits, sampling_metadata, sample_lens)
             # TODO(woosuk): Optimize the memory usage.
+            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
             sampler_output = self.model.sample(
-                logits=logits[recover_logits_idx, :],
+                logits=bonus_logits,
                 sampling_metadata=sampling_metadata,
             )
+            bonus_token_ids = sampler_output.sampled_token_ids
+
             # TODO(woosuk): Optimize the memory usage.
+            target_logits = logits[spec_decode_metadata.target_logits_indices]
             output_token_ids = self.rejection_sampler(
-                draft_token_ids,
+                spec_decode_metadata,
                 None,  # draft_probs
-                target_probs,
-                sampling_metadata)
+                target_logits,
+                bonus_token_ids,
+                sampling_metadata,
+            )
             sampler_output.sampled_token_ids = output_token_ids
 
         # TODO(woosuk): The following loop can be slow since it iterates over
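Note: the else branch now samples one bonus token per request from bonus_logits and hands only the draft-position logits to the Triton-based rejection sampler. As a reference point only, here is a tiny pure-Python sketch of what rejection sampling reduces to in the greedy case (the real kernel also handles probabilistic acceptance for non-greedy sampling; the function name and shapes here are illustrative):

    def greedy_rejection_sample(draft_token_ids: list[int],
                                target_logits,  # [num_draft, vocab] tensor
                                bonus_token_id: int) -> list[int]:
        # Accept draft tokens while they match the target model's argmax.
        # On the first mismatch, emit the target's token and stop; the
        # bonus token is appended only if every draft token was accepted.
        output = []
        for i, draft in enumerate(draft_token_ids):
            target = int(target_logits[i].argmax())
            output.append(target)
            if target != draft:
                return output
        output.append(bonus_token_id)
        return output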
@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             valid_sampled_token_ids = sampled_token_ids.tolist()
         else:
             # Includes spec decode tokens.
-            valid_mask = sampled_token_ids != INVALID_TOKEN_ID
-            gen_lens = valid_mask.sum(dim=1).tolist()
-            # TODO(woosuk): Optimize this.
-            valid_sampled_token_ids = [
-                seq.tolist()
-                for seq in sampled_token_ids[valid_mask].split(gen_lens)
-            ]
+            valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                sampled_token_ids, self.input_batch.vocab_size)
 
         if not self.use_spec_decode:
             spec_token_ids = None
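Note: parse_output replaces the Python-side INVALID_TOKEN_ID masking deleted above. A plausible pure-Python equivalent, assuming rejected or padded slots are filled with an id outside [0, vocab_size) (the helper's exact padding value lives in vllm.v1.sample.rejection_sampler; this sketch is not the real implementation):

    import torch

    def parse_output_sketch(sampled_token_ids: torch.Tensor,
                            vocab_size: int) -> list[list[int]]:
        # [num_reqs, max_sampled] -> per-request lists of accepted tokens,
        # keeping only ids that fall inside the real vocabulary.
        return [[t for t in row if 0 <= t < vocab_size]
                for row in sampled_token_ids.tolist()]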
@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         "initializing the engine.") from e
             else:
                 raise e
+        if self.use_spec_decode:
+            draft_token_ids = [[0] for _ in range(num_reqs)]
+            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+                draft_token_ids, self.device)
+
+            num_tokens = sum(len(ids) for ids in draft_token_ids)
+            # draft_probs = torch.randn(
+            #     num_tokens, logits.shape[-1], device=self.device,
+            #     dtype=logits.dtype)
+            draft_probs = None
+            target_logits = torch.randn(num_tokens,
+                                        logits.shape[-1],
+                                        device=self.device,
+                                        dtype=logits.dtype)
+            # NOTE(woosuk): Here, we should use int32 because the sampler uses
+            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
+            # will occur at runtime.
+            bonus_token_ids = torch.zeros(num_reqs,
+                                          device=self.device,
+                                          dtype=torch.int32)
+            self.rejection_sampler(
+                dummy_spec_decode_metadata,
+                draft_probs,
+                target_logits,
+                bonus_token_ids,
+                dummy_metadata,
+            )
         return sampler_output
 
     def profile_run(self) -> None:
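Note: the NOTE about int32 in the dummy warmup run matters because a compiled sampler specializes on input dtypes: if the warmup tensors do not match the dtypes seen at runtime, the warmup compilation is wasted and the first real batch pays for a recompile. A minimal illustration of that guard behavior, using a hypothetical torch.compile-decorated function as an analogy (the rejection sampler itself is Triton-based):

    import torch

    @torch.compile
    def bump(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    bump(torch.zeros(8, dtype=torch.int32))  # compiles an int32 graph
    bump(torch.ones(8, dtype=torch.int32))   # reuses the compiled graph
    bump(torch.zeros(8, dtype=torch.int64))  # dtype guard fails -> recompile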