Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,77 +4,75 @@ import numpy as np
|
||||
|
||||
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.v1.spec_decode.ngram_proposer import (
|
||||
NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
|
||||
NgramProposer,
|
||||
_find_longest_matched_ngram_and_propose_tokens,
|
||||
)
|
||||
|
||||
|
||||
def test_find_longest_matched_ngram_and_propose_tokens():
    """Exercise _find_longest_matched_ngram_and_propose_tokens directly.

    Checks that the helper (a) returns empty when the suffix n-gram has no
    earlier match, (b) proposes the k tokens that followed the matched
    n-gram, and (c) returns on the FIRST (earliest) match when several exist.
    """
    # Suffix [5, 6] never occurs earlier in the sequence -> no proposal.
    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    result = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
    )
    assert len(result) == 0

    # Suffix [2, 3] matches at index 1; tokens after the match are [4, 1, 2].
    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3
        ),
        np.array([4, 1, 2]),
    )
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
        ),
        np.array([4, 1]),
    )
    # 1-gram suffix [3] also matches at index 2 -> same continuation.
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=3
        ),
        np.array([4, 1, 2]),
    )
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2
        ),
        np.array([4, 1]),
    )

    tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3
        ),
        np.array([4, 1, 2]),
    )
    # Return on the first match
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2
        ),
        np.array([6, 2]),
    )
|
||||
|
||||
|
||||
def test_ngram_proposer():
|
||||
|
||||
def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
|
||||
# Dummy model config. Just to set max_model_len.
|
||||
model_config = ModelConfig(model="facebook/opt-125m")
|
||||
return NgramProposer(
|
||||
vllm_config=VllmConfig(model_config=model_config,
|
||||
speculative_config=SpeculativeConfig(
|
||||
prompt_lookup_min=min_n,
|
||||
prompt_lookup_max=max_n,
|
||||
num_speculative_tokens=k,
|
||||
method="ngram",
|
||||
)))
|
||||
vllm_config=VllmConfig(
|
||||
model_config=model_config,
|
||||
speculative_config=SpeculativeConfig(
|
||||
prompt_lookup_min=min_n,
|
||||
prompt_lookup_max=max_n,
|
||||
num_speculative_tokens=k,
|
||||
method="ngram",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# No match.
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
|
||||
@@ -133,8 +131,7 @@ def test_ngram_proposer():
|
||||
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]]
|
||||
|
||||
# Multiple 3-gram matched, but always pick the first one.
|
||||
token_ids_cpu = np.array(
|
||||
[[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
|
||||
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
@@ -191,6 +188,5 @@ def test_ngram_proposer():
|
||||
spec_decode_unsupported_reqs=(),
|
||||
)
|
||||
assert len(result[0]) == 2
|
||||
assert np.array_equal(result[0],
|
||||
np.array([middle_integer + 2, middle_integer + 3]))
|
||||
assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
|
||||
assert np.array_equal(result[1], np.array([]))
|
||||
|
||||
Reference in New Issue
Block a user