[Cleanup] Remove obsolete spec decoding compatibility logic (#32003)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -135,7 +135,6 @@ def benchmark_batched_propose(args):
|
||||
block_sizes=[16],
|
||||
)
|
||||
dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
|
||||
dummy_input_batch.spec_decode_unsupported_reqs = ()
|
||||
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
|
||||
dummy_input_batch.token_ids_cpu = np.random.randint(
|
||||
0, 20, (args.num_req, args.num_token)
|
||||
@@ -151,10 +150,8 @@ def benchmark_batched_propose(args):
|
||||
start = time.time()
|
||||
runner.drafter.propose(
|
||||
sampled_token_ids,
|
||||
dummy_input_batch.req_ids,
|
||||
dummy_input_batch.num_tokens_no_spec,
|
||||
dummy_input_batch.token_ids_cpu,
|
||||
dummy_input_batch.spec_decode_unsupported_reqs,
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Iteration time (s): {end - start}")
|
||||
|
||||
@@ -940,27 +940,62 @@ def test_correct_decoded_token_preserves_valid_tokens():
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
{
|
||||
"method": "eagle",
|
||||
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
0,
|
||||
),
|
||||
marks=large_gpu_mark(min_gb=32),
|
||||
id="eagle0",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
{
|
||||
"method": "eagle",
|
||||
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
3,
|
||||
),
|
||||
marks=large_gpu_mark(min_gb=32),
|
||||
id="eagle3",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"ngram",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
{
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
3,
|
||||
),
|
||||
marks=large_gpu_mark(min_gb=32),
|
||||
id="ngram",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("top_logprobs", [0, 3])
|
||||
def test_spec_decode_logprobs(
|
||||
logprobs_mode: LogprobsMode,
|
||||
model_setup: tuple[str, str, str],
|
||||
top_logprobs: int,
|
||||
model_setup: tuple[str, str, dict, int],
|
||||
):
|
||||
"""Spec decode logprobs should match those of the base model.
|
||||
|
||||
Args:
|
||||
logprobs_mode: logprobs mode.
|
||||
model_setup: Spec decode method, base model name, and
|
||||
draft model name.
|
||||
model_setup: Tuple of (method, base model name,
|
||||
speculative_config dict, top_logprobs).
|
||||
"""
|
||||
from vllm import LLM
|
||||
|
||||
method, model_name, spec_config, top_logprobs = model_setup
|
||||
|
||||
prompt = "Hello world " * 50
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
|
||||
@@ -972,7 +1007,7 @@ def test_spec_decode_logprobs(
|
||||
ignore_eos=False,
|
||||
presence_penalty=-1.0,
|
||||
)
|
||||
method, model_name, spec_model_name = model_setup
|
||||
|
||||
max_model_len = 256
|
||||
|
||||
# Run base LLM.
|
||||
@@ -999,14 +1034,11 @@ def test_spec_decode_logprobs(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
# Run spec decode LLM.
|
||||
# Add max_model_len to spec_config if not present
|
||||
spec_config_with_len = {**spec_config, "max_model_len": max_model_len}
|
||||
spec_llm = LLM(
|
||||
model_name,
|
||||
speculative_config={
|
||||
"method": method,
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
"max_model_len": max_model_len,
|
||||
},
|
||||
speculative_config=spec_config_with_len,
|
||||
max_logprobs=5,
|
||||
max_model_len=max_model_len,
|
||||
seed=42,
|
||||
|
||||
@@ -82,10 +82,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result[0]) == 0
|
||||
|
||||
@@ -93,10 +91,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result[0]) == 0
|
||||
|
||||
@@ -104,10 +100,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert np.array_equal(result, np.array([[4, 1]]))
|
||||
|
||||
@@ -116,10 +110,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]]
|
||||
|
||||
@@ -127,10 +119,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]]
|
||||
|
||||
@@ -138,10 +128,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert np.array_equal(result, np.array([[100, 1]]))
|
||||
|
||||
@@ -149,10 +137,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result[0]) == 0
|
||||
|
||||
@@ -162,10 +148,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
|
||||
sampled_token_ids=[[0], [1]],
|
||||
req_ids=["0", "1"],
|
||||
num_tokens_no_spec=np.array([5, 3]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result[0]) == 2
|
||||
assert np.array_equal(result[0], np.array([3, 1]))
|
||||
@@ -183,10 +167,8 @@ def test_ngram_proposer():
|
||||
sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill
|
||||
result = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
req_ids=["0", "1", "2"],
|
||||
num_tokens_no_spec=num_tokens_no_spec,
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result) == 3
|
||||
assert np.array_equal(result[0], [3, 1])
|
||||
@@ -214,10 +196,8 @@ def test_ngram_proposer():
|
||||
token_ids_cpu = np.array([input_1, input_2])
|
||||
result = ngram_proposer.propose(
|
||||
sampled_token_ids=[[0], [1]],
|
||||
req_ids=["0", "1"],
|
||||
num_tokens_no_spec=np.array([len(input_1), 3]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result[0]) == 2
|
||||
assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
|
||||
|
||||
@@ -55,10 +55,8 @@ class NgramProposer:
|
||||
# This usually takes less than 1 second.
|
||||
self.propose(
|
||||
[[]] * 1024,
|
||||
[""] * 1024,
|
||||
np.zeros(1024, dtype=np.int32),
|
||||
np.zeros((1024, self.max_model_len), dtype=np.int32),
|
||||
set(),
|
||||
)
|
||||
|
||||
def batch_propose(
|
||||
@@ -132,10 +130,8 @@ class NgramProposer:
|
||||
def propose(
|
||||
self,
|
||||
sampled_token_ids: list[list[int]],
|
||||
req_ids: list[str],
|
||||
num_tokens_no_spec: np.ndarray,
|
||||
token_ids_cpu: np.ndarray,
|
||||
spec_decode_unsupported_reqs: set,
|
||||
) -> list[list[int]]:
|
||||
# find which requests need ngram proposals
|
||||
valid_ngram_requests = []
|
||||
@@ -145,12 +141,6 @@ class NgramProposer:
|
||||
# Skip speculative decoding.
|
||||
continue
|
||||
|
||||
# Skip requests that require sampling parameters that are not
|
||||
# supported with speculative decoding.
|
||||
req_id = req_ids[i]
|
||||
if req_id in spec_decode_unsupported_reqs:
|
||||
continue
|
||||
|
||||
num_tokens = num_tokens_no_spec[i]
|
||||
if num_tokens >= self.max_model_len:
|
||||
# Skip requests that have already reached the max model length.
|
||||
|
||||
@@ -46,13 +46,7 @@ class SuffixDecodingProposer:
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
# Skip requests that require sampling parameters that are not
|
||||
# supported with speculative decoding.
|
||||
req_id = input_batch.req_ids[i]
|
||||
if req_id in input_batch.spec_decode_unsupported_reqs:
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
num_tokens = input_batch.num_tokens_no_spec[i]
|
||||
if num_tokens >= self.max_model_len:
|
||||
# Skip requests that have already reached the max model length.
|
||||
|
||||
@@ -1,21 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
|
||||
"""True if request is incompatible with speculative decoding"""
|
||||
return (
|
||||
sampling_params.frequency_penalty != 0.0
|
||||
or sampling_params.presence_penalty != 0.0
|
||||
or sampling_params.repetition_penalty != 1.0
|
||||
or sampling_params.min_p > _SAMPLING_EPS
|
||||
or sampling_params.logprobs is not None
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def eagle_prepare_inputs_padded_kernel(
|
||||
|
||||
@@ -22,7 +22,6 @@ from vllm.v1.sample.logits_processor import (
|
||||
MoveDirectionality,
|
||||
)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
@@ -176,9 +175,6 @@ class InputBatch:
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
# IDs of requests which do not support spec decoding
|
||||
self.spec_decode_unsupported_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device=device
|
||||
@@ -346,8 +342,6 @@ class InputBatch:
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
|
||||
self.spec_decode_unsupported_reqs.add(req_id)
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Should avoid division by zero later when apply_temperature.
|
||||
self.temperature_cpu[req_index] = 0.0
|
||||
@@ -510,7 +504,6 @@ class InputBatch:
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.spec_decode_unsupported_reqs.discard(req_id)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
|
||||
@@ -3620,10 +3620,8 @@ class GPUModelRunner(
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
draft_token_ids = self.drafter.propose(
|
||||
sampled_token_ids,
|
||||
self.input_batch.req_ids,
|
||||
self.input_batch.num_tokens_no_spec,
|
||||
self.input_batch.token_ids_cpu,
|
||||
self.input_batch.spec_decode_unsupported_reqs,
|
||||
)
|
||||
elif spec_config.method == "suffix":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
|
||||
Reference in New Issue
Block a user