[Spec Decode] Add Batch Parallel Ngram. Upto 8x lower overhead. (#24986)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -17,7 +17,7 @@ PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
|
||||
GREEDY_TEMPERATURE: tl.constexpr = -1
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 32
|
||||
MAX_SPEC_LEN = 128
|
||||
|
||||
|
||||
class RejectionSampler(nn.Module):
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from numba import jit
|
||||
from numba import get_num_threads, jit, njit, prange, set_num_threads
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
@@ -26,55 +26,174 @@ class NgramProposer:
|
||||
# Maximum length of the model.
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
# Pre-allocate buffers for numba batch propose.
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self.valid_ngram_draft = np.zeros((max_num_seqs, self.k),
|
||||
dtype=np.int32)
|
||||
self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)
|
||||
|
||||
# Threshold of total number of tokens in the batch to enable
|
||||
# multi-threading in numba batch propose.
|
||||
self.num_tokens_threshold = 8192
|
||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
cpu_count = os.cpu_count()
|
||||
# Max number of threads for numba parallel processing.
|
||||
if cpu_count:
|
||||
# Divide by 2 to use physical cores
|
||||
# and not logical cores (hyper-threading).
|
||||
# Cap the number of threads to 8 to avoid using too many threads
|
||||
# since other components like frontend (incl tokenization)
|
||||
# and Structured Outputs also use multiple threads.
|
||||
# TODO(ekagra-ranjan): bump up the cap from 1 to 8
|
||||
# when TP parallelization for ngram is implemented.
|
||||
self.num_numba_thread_available = min(1, (cpu_count // 2))
|
||||
# Divide by tp_size to ensure each tensor parallel rank
|
||||
# has some threads since all ranks will run this.
|
||||
self.num_numba_thread_available //= tp_size
|
||||
else:
|
||||
self.num_numba_thread_available = 1
|
||||
|
||||
# Trigger Numba JIT compilation for N-gram proposer.
|
||||
# This usually takes less than 1 second.
|
||||
self.propose(np.zeros(1024, dtype=np.int32))
|
||||
self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
|
||||
np.zeros((1024, self.max_model_len), dtype=np.int32),
|
||||
set())
|
||||
|
||||
def batch_propose(
|
||||
self,
|
||||
num_requests: int,
|
||||
valid_ngram_requests: list,
|
||||
num_tokens_no_spec: np.ndarray,
|
||||
token_ids_cpu: np.ndarray,
|
||||
) -> list[list[int]]:
|
||||
"""Batch version of ngram proposer using numba for acceleration.
|
||||
|
||||
Args:
|
||||
valid_ngram_requests:
|
||||
Set of indices of requests that need ngram proposals.
|
||||
num_tokens_no_spec:
|
||||
Numpy array of shape (batch_size,) representing the number
|
||||
of tokens without speculative tokens for each request.
|
||||
token_ids_cpu:
|
||||
Numpy array of shape (batch_size, max_model_len)
|
||||
representing the token IDs for each request.
|
||||
|
||||
Returns:
|
||||
list[list[int]]:
|
||||
A list where each element is a list of proposed
|
||||
token IDs for the corresponding request.
|
||||
"""
|
||||
draft_token_ids: list[list[int]] = []
|
||||
|
||||
# Only run batch propose if there are requests needing ngram proposals.
|
||||
# avoid calling numba function with empty list which causes error
|
||||
# ValueError: cannot compute fingerprint of empty list
|
||||
if num_ngram_requests := len(valid_ngram_requests):
|
||||
original_num_numba_threads = get_num_threads()
|
||||
# Ensure we use at least one thread.
|
||||
# If total tokens is small, using multiple threads
|
||||
# may slow down due to overhead.
|
||||
total_tokens = np.sum(num_tokens_no_spec)
|
||||
if total_tokens >= self.num_tokens_threshold:
|
||||
final_num_threads = max(
|
||||
1, min(self.num_numba_thread_available,
|
||||
num_ngram_requests))
|
||||
set_num_threads(final_num_threads)
|
||||
else:
|
||||
set_num_threads(1)
|
||||
|
||||
batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
|
||||
token_ids_cpu, self.min_n, self.max_n,
|
||||
self.max_model_len, self.k,
|
||||
self.valid_ngram_draft,
|
||||
self.valid_ngram_num_drafts)
|
||||
|
||||
# Restore original number of threads.
|
||||
set_num_threads(original_num_numba_threads)
|
||||
|
||||
for i in range(num_requests):
|
||||
if i in valid_ngram_requests and \
|
||||
self.valid_ngram_num_drafts[i] > 0:
|
||||
draft_token_ids.append(self.valid_ngram_draft[
|
||||
i, :self.valid_ngram_num_drafts[i]].tolist())
|
||||
else:
|
||||
draft_token_ids.append([])
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def propose(
|
||||
self,
|
||||
context_token_ids: np.ndarray,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""Proposes the next sequence of tokens based on n-gram pattern
|
||||
matching in the context. The function finds matches of the last n
|
||||
tokens in the previous context, and returns k tokens that followed
|
||||
that match.
|
||||
|
||||
Args:
|
||||
context_token_ids: Numpy array of token IDs representing the
|
||||
context sequence.
|
||||
sampled_token_ids: list[list[int]],
|
||||
req_ids: list[str],
|
||||
num_tokens_no_spec: np.ndarray,
|
||||
token_ids_cpu: np.ndarray,
|
||||
spec_decode_unsupported_reqs: set,
|
||||
) -> list[list[int]]:
|
||||
|
||||
Returns:
|
||||
np.ndarray: The sequence of tokens that followed
|
||||
the matched n-gram in the context.
|
||||
None: If no matching n-gram pattern is found.
|
||||
# find which requests need ngram proposals
|
||||
valid_ngram_requests = []
|
||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||
num_sampled_ids = len(sampled_ids)
|
||||
if not num_sampled_ids:
|
||||
# Skip speculative decoding.
|
||||
continue
|
||||
|
||||
Example:
|
||||
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
|
||||
k = 4:
|
||||
- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
|
||||
- The last 2 tokens [2,3] will be matched against the previous
|
||||
4 tokens [1,2,3,4].
|
||||
- Finding a match of [2,3] would return the tokens that
|
||||
followed that pattern. Here we will return [4,2,3] because
|
||||
we only have three tokens after the match.
|
||||
"""
|
||||
# TODO(woosuk): Optimize this.
|
||||
return _find_longest_matched_ngram_and_propose_tokens(
|
||||
origin_tokens=context_token_ids,
|
||||
min_ngram=self.min_n,
|
||||
max_ngram=self.max_n,
|
||||
max_model_len=self.max_model_len,
|
||||
k=self.k)
|
||||
# Skip requests that require sampling parameters that are not
|
||||
# supported with speculative decoding.
|
||||
req_id = req_ids[i]
|
||||
if req_id in spec_decode_unsupported_reqs:
|
||||
continue
|
||||
|
||||
num_tokens = num_tokens_no_spec[i]
|
||||
if num_tokens >= self.max_model_len:
|
||||
# Skip requests that have already reached the max model length.
|
||||
continue
|
||||
|
||||
valid_ngram_requests.append(i)
|
||||
|
||||
draft_token_ids = self.batch_propose(
|
||||
len(sampled_token_ids),
|
||||
valid_ngram_requests,
|
||||
num_tokens_no_spec,
|
||||
token_ids_cpu,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
# No model to load.
|
||||
pass
|
||||
|
||||
|
||||
@njit(parallel=True)
|
||||
def batch_propose_numba(valid_ngram_requests: list,
|
||||
num_tokens_no_spec: np.ndarray,
|
||||
token_ids_cpu: np.ndarray, min_n: int, max_n: int,
|
||||
max_model_len: int, k: int,
|
||||
valid_ngram_draft: np.ndarray,
|
||||
valid_ngram_num_drafts: np.ndarray):
|
||||
for i in prange(len(valid_ngram_requests)):
|
||||
idx = valid_ngram_requests[i]
|
||||
num_tokens = num_tokens_no_spec[idx]
|
||||
context_token_ids = token_ids_cpu[idx, :num_tokens]
|
||||
drafter_output = _find_longest_matched_ngram_and_propose_tokens(
|
||||
origin_tokens=context_token_ids,
|
||||
min_ngram=min_n,
|
||||
max_ngram=max_n,
|
||||
max_model_len=max_model_len,
|
||||
k=k)
|
||||
|
||||
valid_ngram_num_drafts[i] = drafter_output.shape[0]
|
||||
if len(drafter_output):
|
||||
valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def _find_longest_matched_ngram_and_propose_tokens(
|
||||
origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
|
||||
max_model_len: int, k: int) -> Optional[np.ndarray]:
|
||||
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray,
|
||||
min_ngram: int,
|
||||
max_ngram: int,
|
||||
max_model_len: int,
|
||||
k: int) -> np.ndarray:
|
||||
"""
|
||||
Find the longest n-gram which matches the suffix of the given tokens
|
||||
whose length is within [min_ngram, max_ngram] (inclusive).
|
||||
@@ -84,12 +203,12 @@ def _find_longest_matched_ngram_and_propose_tokens(
|
||||
# Do not generate draft tokens is context is shorter than minimum n-gram
|
||||
total_token = origin_tokens.shape[0]
|
||||
if total_token < min_ngram:
|
||||
return None
|
||||
return np.empty((0, ), dtype=origin_tokens.dtype)
|
||||
|
||||
# Do not generate draft tokens beyond the max model length.
|
||||
k = min(k, max_model_len - total_token)
|
||||
if k <= 0:
|
||||
return None
|
||||
return np.empty((0, ), dtype=origin_tokens.dtype)
|
||||
|
||||
# Flip tokens, and the goal become to find longest ngram
|
||||
# on the rightmost position which matches the prefix with
|
||||
@@ -146,7 +265,7 @@ def _find_longest_matched_ngram_and_propose_tokens(
|
||||
|
||||
if longest_ngram < min_ngram:
|
||||
# No valid ngram is found
|
||||
return None
|
||||
return np.empty((0, ), dtype=origin_tokens.dtype)
|
||||
|
||||
# Flip the position back, so in origin_tokens,
|
||||
# origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
|
||||
|
||||
@@ -2404,8 +2404,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.speculative_config.method == "ngram":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
draft_token_ids = self.propose_ngram_draft_token_ids(
|
||||
sampled_token_ids)
|
||||
draft_token_ids = self.drafter.propose(
|
||||
sampled_token_ids, self.input_batch.req_ids,
|
||||
self.input_batch.num_tokens_no_spec,
|
||||
self.input_batch.token_ids_cpu,
|
||||
self.input_batch.spec_decode_unsupported_reqs)
|
||||
elif self.speculative_config.method == "medusa":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
assert isinstance(self.drafter, MedusaProposer)
|
||||
@@ -2515,41 +2518,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
return draft_token_ids
|
||||
|
||||
def propose_ngram_draft_token_ids(
|
||||
self,
|
||||
sampled_token_ids: list[list[int]],
|
||||
) -> list[list[int]]:
|
||||
# TODO(woosuk): Optimize.
|
||||
req_ids = self.input_batch.req_ids
|
||||
draft_token_ids: list[list[int]] = []
|
||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||
num_sampled_ids = len(sampled_ids)
|
||||
if not num_sampled_ids:
|
||||
# Skip speculative decoding.
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
# Skip requests that require sampling parameters that are not
|
||||
# supported with speculative decoding.
|
||||
req_id = req_ids[i]
|
||||
if req_id in self.input_batch.spec_decode_unsupported_reqs:
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
num_tokens = self.input_batch.num_tokens_no_spec[i]
|
||||
if num_tokens >= self.max_model_len:
|
||||
# Skip requests that have already reached the max model length.
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
drafter_output = self.drafter.propose(
|
||||
self.input_batch.token_ids_cpu[i, :num_tokens])
|
||||
if drafter_output is None or len(drafter_output) == 0:
|
||||
draft_token_ids.append([])
|
||||
else:
|
||||
draft_token_ids.append(drafter_output.tolist())
|
||||
return draft_token_ids
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
allowed_config_names = {"load_config", "model_config"}
|
||||
for config_name, config_overrides in overrides.items():
|
||||
|
||||
Reference in New Issue
Block a user