[Cleanup] Remove obsolete spec decoding compatibility logic (#32003)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-01-08 21:44:18 -08:00
committed by GitHub
parent 7a05d2dc65
commit 29ce48221c
8 changed files with 45 additions and 75 deletions

View File

@@ -135,7 +135,6 @@ def benchmark_batched_propose(args):
         block_sizes=[16],
     )
     dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
-    dummy_input_batch.spec_decode_unsupported_reqs = ()
     dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
     dummy_input_batch.token_ids_cpu = np.random.randint(
         0, 20, (args.num_req, args.num_token)
@@ -151,10 +150,8 @@ def benchmark_batched_propose(args):
     start = time.time()
     runner.drafter.propose(
         sampled_token_ids,
-        dummy_input_batch.req_ids,
         dummy_input_batch.num_tokens_no_spec,
         dummy_input_batch.token_ids_cpu,
-        dummy_input_batch.spec_decode_unsupported_reqs,
     )
     end = time.time()
     print(f"Iteration time (s): {end - start}")

View File

@@ -940,27 +940,62 @@ def test_correct_decoded_token_preserves_valid_tokens():
            (
                "eagle",
                "meta-llama/Llama-3.2-1B-Instruct",
-               "nm-testing/Llama3_2_1B_speculator.eagle3",
+               {
+                   "method": "eagle",
+                   "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
+                   "num_speculative_tokens": 3,
+               },
+               0,
            ),
            marks=large_gpu_mark(min_gb=32),
+           id="eagle0",
+       ),
+       pytest.param(
+           (
+               "eagle",
+               "meta-llama/Llama-3.2-1B-Instruct",
+               {
+                   "method": "eagle",
+                   "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
+                   "num_speculative_tokens": 3,
+               },
+               3,
+           ),
+           marks=large_gpu_mark(min_gb=32),
+           id="eagle3",
+       ),
+       pytest.param(
+           (
+               "ngram",
+               "meta-llama/Llama-3.2-1B-Instruct",
+               {
+                   "method": "ngram",
+                   "prompt_lookup_max": 5,
+                   "prompt_lookup_min": 3,
+                   "num_speculative_tokens": 3,
+               },
+               3,
+           ),
+           marks=large_gpu_mark(min_gb=32),
+           id="ngram",
        ),
    ],
 )
-@pytest.mark.parametrize("top_logprobs", [0, 3])
 def test_spec_decode_logprobs(
     logprobs_mode: LogprobsMode,
-    model_setup: tuple[str, str, str],
-    top_logprobs: int,
+    model_setup: tuple[str, str, dict, int],
 ):
     """Spec decode logprobs should match those of the base model.

     Args:
         logprobs_mode: logprobs mode.
-        model_setup: Spec decode method, base model name, and
-            draft model name.
+        model_setup: Tuple of (method, base model name,
+            speculative_config dict, top_logprobs).
     """
     from vllm import LLM

+    method, model_name, spec_config, top_logprobs = model_setup
     prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
@@ -972,7 +1007,7 @@ def test_spec_decode_logprobs(
         ignore_eos=False,
         presence_penalty=-1.0,
     )
-    method, model_name, spec_model_name = model_setup
+
     max_model_len = 256

     # Run base LLM.
@@ -999,14 +1034,11 @@ def test_spec_decode_logprobs(
     cleanup_dist_env_and_memory()

     # Run spec decode LLM.
+    # Add max_model_len to spec_config if not present
+    spec_config_with_len = {**spec_config, "max_model_len": max_model_len}
     spec_llm = LLM(
         model_name,
-        speculative_config={
-            "method": method,
-            "model": spec_model_name,
-            "num_speculative_tokens": 3,
-            "max_model_len": max_model_len,
-        },
+        speculative_config=spec_config_with_len,
         max_logprobs=5,
         max_model_len=max_model_len,
         seed=42,

View File

@@ -82,10 +82,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
     result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert len(result[0]) == 0
@@ -93,10 +91,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
     result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert len(result[0]) == 0
@@ -104,10 +100,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
     result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert np.array_equal(result, np.array([[4, 1]]))
@@ -116,10 +110,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
     result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 1]]
@@ -127,10 +119,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
     result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 2]]
@@ -138,10 +128,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
     result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert np.array_equal(result, np.array([[100, 1]]))
@@ -149,10 +137,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[]])
     result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
         sampled_token_ids=[[0]],
-        req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert len(result[0]) == 0
@@ -162,10 +148,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
     result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
         sampled_token_ids=[[0], [1]],
-        req_ids=["0", "1"],
         num_tokens_no_spec=np.array([5, 3]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert len(result[0]) == 2
     assert np.array_equal(result[0], np.array([3, 1]))
@@ -183,10 +167,8 @@ def test_ngram_proposer():
     sampled_token_ids = [[2], [], [8]]  # Empty list for request 1 simulates prefill
     result = proposer.propose(
         sampled_token_ids=sampled_token_ids,
-        req_ids=["0", "1", "2"],
         num_tokens_no_spec=num_tokens_no_spec,
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert len(result) == 3
     assert np.array_equal(result[0], [3, 1])
@@ -214,10 +196,8 @@ def test_ngram_proposer():
     token_ids_cpu = np.array([input_1, input_2])
     result = ngram_proposer.propose(
         sampled_token_ids=[[0], [1]],
-        req_ids=["0", "1"],
         num_tokens_no_spec=np.array([len(input_1), 3]),
         token_ids_cpu=token_ids_cpu,
-        spec_decode_unsupported_reqs=(),
     )
     assert len(result[0]) == 2
     assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))

View File

@@ -55,10 +55,8 @@ class NgramProposer:
         # This usually takes less than 1 second.
         self.propose(
             [[]] * 1024,
-            [""] * 1024,
             np.zeros(1024, dtype=np.int32),
             np.zeros((1024, self.max_model_len), dtype=np.int32),
-            set(),
         )

     def batch_propose(
@@ -132,10 +130,8 @@ class NgramProposer:
     def propose(
         self,
         sampled_token_ids: list[list[int]],
-        req_ids: list[str],
         num_tokens_no_spec: np.ndarray,
         token_ids_cpu: np.ndarray,
-        spec_decode_unsupported_reqs: set,
     ) -> list[list[int]]:
         # find which requests need ngram proposals
         valid_ngram_requests = []
@@ -145,12 +141,6 @@ class NgramProposer:
                 # Skip speculative decoding.
                 continue

-            # Skip requests that require sampling parameters that are not
-            # supported with speculative decoding.
-            req_id = req_ids[i]
-            if req_id in spec_decode_unsupported_reqs:
-                continue
-
             num_tokens = num_tokens_no_spec[i]
             if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.

View File

@@ -46,13 +46,7 @@ class SuffixDecodingProposer:
                 draft_token_ids.append([])
                 continue

-            # Skip requests that require sampling parameters that are not
-            # supported with speculative decoding.
             req_id = input_batch.req_ids[i]
-            if req_id in input_batch.spec_decode_unsupported_reqs:
-                draft_token_ids.append([])
-                continue
-
             num_tokens = input_batch.num_tokens_no_spec[i]
             if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.

View File

@@ -1,21 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from vllm.sampling_params import SamplingParams
 from vllm.triton_utils import tl, triton

-_SAMPLING_EPS = 1e-5
-
-
-def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
-    """True if request is incompatible with speculative decoding"""
-    return (
-        sampling_params.frequency_penalty != 0.0
-        or sampling_params.presence_penalty != 0.0
-        or sampling_params.repetition_penalty != 1.0
-        or sampling_params.min_p > _SAMPLING_EPS
-        or sampling_params.logprobs is not None
-    )
-
-
 @triton.jit
 def eagle_prepare_inputs_padded_kernel(

View File

@@ -22,7 +22,6 @@ from vllm.v1.sample.logits_processor import (
     MoveDirectionality,
 )
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import MultiGroupBlockTable

@@ -176,9 +175,6 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: set[str] = set()

-        # IDs of requests which do not support spec decoding
-        self.spec_decode_unsupported_reqs: set[str] = set()
-
         # Frequency penalty related data structures
         self.frequency_penalties = torch.empty(
             (max_num_reqs,), dtype=torch.float, device=device
@@ -346,8 +342,6 @@ class InputBatch:
         self.block_table.add_row(request.block_ids, req_index)

         if sampling_params := request.sampling_params:
-            if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
-                self.spec_decode_unsupported_reqs.add(req_id)
             if sampling_params.sampling_type == SamplingType.GREEDY:
                 # Should avoid division by zero later when apply_temperature.
                 self.temperature_cpu[req_index] = 0.0
@@ -510,7 +504,6 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
-        self.spec_decode_unsupported_reqs.discard(req_id)
         self.frequency_penalties_reqs.discard(req_id)
         self.presence_penalties_reqs.discard(req_id)
         self.repetition_penalties_reqs.discard(req_id)

View File

@@ -3620,10 +3620,8 @@ class GPUModelRunner(
             assert isinstance(self.drafter, NgramProposer)
             draft_token_ids = self.drafter.propose(
                 sampled_token_ids,
-                self.input_batch.req_ids,
                 self.input_batch.num_tokens_no_spec,
                 self.input_batch.token_ids_cpu,
-                self.input_batch.spec_decode_unsupported_reqs,
             )
         elif spec_config.method == "suffix":
             assert isinstance(sampled_token_ids, list)