diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index b5373d383..57a6c1aef 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -135,7 +135,6 @@ def benchmark_batched_propose(args): block_sizes=[16], ) dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req)) - dummy_input_batch.spec_decode_unsupported_reqs = () dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req dummy_input_batch.token_ids_cpu = np.random.randint( 0, 20, (args.num_req, args.num_token) @@ -151,10 +150,8 @@ def benchmark_batched_propose(args): start = time.time() runner.drafter.propose( sampled_token_ids, - dummy_input_batch.req_ids, dummy_input_batch.num_tokens_no_spec, dummy_input_batch.token_ids_cpu, - dummy_input_batch.spec_decode_unsupported_reqs, ) end = time.time() print(f"Iteration time (s): {end - start}") diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index abb3ce2ef..3c7ed77a8 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -940,27 +940,62 @@ def test_correct_decoded_token_preserves_valid_tokens(): ( "eagle", "meta-llama/Llama-3.2-1B-Instruct", - "nm-testing/Llama3_2_1B_speculator.eagle3", + { + "method": "eagle", + "model": "nm-testing/Llama3_2_1B_speculator.eagle3", + "num_speculative_tokens": 3, + }, + 0, ), marks=large_gpu_mark(min_gb=32), + id="eagle0", + ), + pytest.param( + ( + "eagle", + "meta-llama/Llama-3.2-1B-Instruct", + { + "method": "eagle", + "model": "nm-testing/Llama3_2_1B_speculator.eagle3", + "num_speculative_tokens": 3, + }, + 3, + ), + marks=large_gpu_mark(min_gb=32), + id="eagle3", + ), + pytest.param( + ( + "ngram", + "meta-llama/Llama-3.2-1B-Instruct", + { + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + 3, + ), + marks=large_gpu_mark(min_gb=32), + id="ngram", ), ], ) 
-@pytest.mark.parametrize("top_logprobs", [0, 3]) def test_spec_decode_logprobs( logprobs_mode: LogprobsMode, - model_setup: tuple[str, str, str], - top_logprobs: int, + model_setup: tuple[str, str, dict, int], ): """Spec decode logprobs should match those of the base model. Args: logprobs_mode: logprobs mode. - model_setup: Spec decode method, base model name, and - draft model name. + model_setup: Tuple of (method, base model name, + speculative_config dict, top_logprobs). """ from vllm import LLM + method, model_name, spec_config, top_logprobs = model_setup + prompt = "Hello world " * 50 sampling_params = SamplingParams( temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False @@ -972,7 +1007,7 @@ def test_spec_decode_logprobs( ignore_eos=False, presence_penalty=-1.0, ) - method, model_name, spec_model_name = model_setup + max_model_len = 256 # Run base LLM. @@ -999,14 +1034,11 @@ def test_spec_decode_logprobs( cleanup_dist_env_and_memory() # Run spec decode LLM. + # Merge max_model_len into spec_config (overrides any preexisting value) + spec_config_with_len = {**spec_config, "max_model_len": max_model_len} spec_llm = LLM( model_name, - speculative_config={ - "method": method, - "model": spec_model_name, - "num_speculative_tokens": 3, - "max_model_len": max_model_len, - }, + speculative_config=spec_config_with_len, max_logprobs=5, max_model_len=max_model_len, seed=42, diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 6bc412abe..7d2a07ddc 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -82,10 +82,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 0 @@ -93,10 +91,8 @@ def test_ngram_proposer(): token_ids_cpu =
np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 0 @@ -104,10 +100,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert np.array_equal(result, np.array([[4, 1]])) @@ -116,10 +110,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]] @@ -127,10 +119,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]] @@ -138,10 +128,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert np.array_equal(result, np.array([[100, 1]])) @@ -149,10 +137,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[]]) result = 
get_ngram_proposer(min_n=2, max_n=2, k=2).propose( sampled_token_ids=[[0]], - req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 0 @@ -162,10 +148,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( sampled_token_ids=[[0], [1]], - req_ids=["0", "1"], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 2 assert np.array_equal(result[0], np.array([3, 1])) @@ -183,10 +167,8 @@ def test_ngram_proposer(): sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill result = proposer.propose( sampled_token_ids=sampled_token_ids, - req_ids=["0", "1", "2"], num_tokens_no_spec=num_tokens_no_spec, token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert len(result) == 3 assert np.array_equal(result[0], [3, 1]) @@ -214,10 +196,8 @@ def test_ngram_proposer(): token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( sampled_token_ids=[[0], [1]], - req_ids=["0", "1"], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 2 assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3])) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 1273ca12c..f97d54e63 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -55,10 +55,8 @@ class NgramProposer: # This usually takes less than 1 second. 
self.propose( [[]] * 1024, - [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), - set(), ) def batch_propose( @@ -132,10 +130,8 @@ class NgramProposer: def propose( self, sampled_token_ids: list[list[int]], - req_ids: list[str], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, - spec_decode_unsupported_reqs: set, ) -> list[list[int]]: # find which requests need ngram proposals valid_ngram_requests = [] @@ -145,12 +141,6 @@ class NgramProposer: # Skip speculative decoding. continue - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = req_ids[i] - if req_id in spec_decode_unsupported_reqs: - continue - num_tokens = num_tokens_no_spec[i] if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 049e335db..5d6dcc552 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -46,13 +46,7 @@ class SuffixDecodingProposer: draft_token_ids.append([]) continue - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. req_id = input_batch.req_ids[i] - if req_id in input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - num_tokens = input_batch.num_tokens_no_spec[i] if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. 
diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 783b6ed59..524714db3 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,21 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.sampling_params import SamplingParams from vllm.triton_utils import tl, triton -_SAMPLING_EPS = 1e-5 - - -def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: - """True if request is incompatible with speculative decoding""" - return ( - sampling_params.frequency_penalty != 0.0 - or sampling_params.presence_penalty != 0.0 - or sampling_params.repetition_penalty != 1.0 - or sampling_params.min_p > _SAMPLING_EPS - or sampling_params.logprobs is not None - ) - @triton.jit def eagle_prepare_inputs_padded_kernel( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 944465224..662badeb5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -22,7 +22,6 @@ from vllm.v1.sample.logits_processor import ( MoveDirectionality, ) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import MultiGroupBlockTable @@ -176,9 +175,6 @@ class InputBatch: self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() - # IDs of requests which do not support spec decoding - self.spec_decode_unsupported_reqs: set[str] = set() - # Frequency penalty related data structures self.frequency_penalties = torch.empty( (max_num_reqs,), dtype=torch.float, device=device @@ -346,8 +342,6 @@ class InputBatch: self.block_table.add_row(request.block_ids, req_index) if sampling_params := request.sampling_params: - if self.is_spec_decode and is_spec_decode_unsupported(sampling_params): - self.spec_decode_unsupported_reqs.add(req_id) if 
sampling_params.sampling_type == SamplingType.GREEDY: # Should avoid division by zero later when apply_temperature. self.temperature_cpu[req_index] = 0.0 @@ -510,7 +504,6 @@ class InputBatch: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) - self.spec_decode_unsupported_reqs.discard(req_id) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bd57708ab..5228167ed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3620,10 +3620,8 @@ class GPUModelRunner( assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( sampled_token_ids, - self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs, ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list)