Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import dataclasses
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -170,7 +169,7 @@ def test_schedule_partial_requests():
|
||||
req_id_to_index=req_to_index,
|
||||
# Only the first request has a sampled token id because
|
||||
# the rest requests are still being prefilled.
|
||||
sampled_token_ids=[np.array([0]), np.array([]), np.array([])],
|
||||
sampled_token_ids=[[0], [], []],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -217,7 +216,7 @@ def test_no_mm_input_chunking():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -277,7 +276,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -301,8 +300,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([0]), np.array([0])]
|
||||
+ [np.array([]) for _ in range(len(requests) - 2)],
|
||||
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -349,8 +347,8 @@ def test_stop_via_update_from_output():
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[
|
||||
np.array([EOS_TOKEN_ID]),
|
||||
np.array([10, 11]),
|
||||
[EOS_TOKEN_ID],
|
||||
[10, 11],
|
||||
], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -394,10 +392,7 @@ def test_stop_via_update_from_output():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[
|
||||
np.array([10, 42, 12]),
|
||||
np.array([13, 14]),
|
||||
], # First request hits stop token
|
||||
sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -441,10 +436,7 @@ def test_stop_via_update_from_output():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[
|
||||
np.array([10, 11, 12]),
|
||||
np.array([13]),
|
||||
], # First request exceeds max_tokens
|
||||
sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -483,7 +475,7 @@ def test_stop_via_update_from_output():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -624,7 +616,7 @@ def test_schedule_concurrent_batches(
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -641,7 +633,7 @@ def test_schedule_concurrent_batches(
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -678,7 +670,7 @@ def test_preempt_during_execution():
|
||||
model_runner_output0 = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -695,7 +687,7 @@ def test_preempt_during_execution():
|
||||
model_runner_output1 = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[np.array([42])],
|
||||
sampled_token_ids=[[42]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -712,18 +704,14 @@ def test_preempt_during_execution():
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match
|
||||
([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch
|
||||
(
|
||||
[[1, 2], [3]],
|
||||
[np.array([1, 2, 5]), np.array([3, 4])],
|
||||
(2, 3, 3, [2, 1]),
|
||||
), # multiple sequences
|
||||
([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence
|
||||
([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence
|
||||
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
|
||||
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
|
||||
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences
|
||||
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
|
||||
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
|
||||
(
|
||||
[[1, 2, 3], [4, 5, 6]],
|
||||
[np.array([1, 2, 7]), np.array([4, 8])],
|
||||
[[1, 2, 7], [4, 8]],
|
||||
(2, 6, 3, [2, 1, 0]),
|
||||
), # multiple mismatches
|
||||
],
|
||||
@@ -757,7 +745,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([0]) for _ in range(len(requests))],
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -984,7 +972,7 @@ def test_kv_connector_basic(is_async: bool):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1037,7 +1025,7 @@ def test_kv_connector_basic(is_async: bool):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1100,7 +1088,7 @@ def test_external_prefix_cache_metrics():
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=[r.request_id for r in requests],
|
||||
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([1000])] * NUM_REQUESTS,
|
||||
sampled_token_ids=[[1000]] * NUM_REQUESTS,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1166,7 +1154,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1251,7 +1239,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1344,7 +1332,7 @@ def make_output(scheduler: Scheduler):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in scheduler.running],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
|
||||
sampled_token_ids=[np.array([1000])] * len(scheduler.running),
|
||||
sampled_token_ids=[[1000]] * len(scheduler.running),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1761,7 +1749,7 @@ def test_priority_scheduling_preemption():
|
||||
req_id_to_index={
|
||||
req.request_id: i for i, req in enumerate(low_priority_requests)
|
||||
},
|
||||
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
|
||||
sampled_token_ids=[[100] for _ in low_priority_requests],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -1830,7 +1818,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
|
||||
req_id_to_index={
|
||||
req.request_id: i for i, req in enumerate(low_priority_requests)
|
||||
},
|
||||
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
|
||||
sampled_token_ids=[[100] for _ in low_priority_requests],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -2076,7 +2064,7 @@ def test_priority_scheduling_heap_property():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.req_id],
|
||||
req_id_to_index={req.req_id: 0},
|
||||
sampled_token_ids=[np.array([100])],
|
||||
sampled_token_ids=[[100]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -2162,7 +2150,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[request_low.request_id],
|
||||
req_id_to_index={request_low.request_id: 0},
|
||||
sampled_token_ids=[np.array([100])],
|
||||
sampled_token_ids=[[100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -2193,7 +2181,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([100]) for _ in requests],
|
||||
sampled_token_ids=[[100] for _ in requests],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -2219,7 +2207,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([]), np.array([100])],
|
||||
sampled_token_ids=[[], [100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -2636,7 +2624,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[request1.request_id],
|
||||
req_id_to_index={request1.request_id: 0},
|
||||
sampled_token_ids=[np.array([100])],
|
||||
sampled_token_ids=[[100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -2842,7 +2830,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@@ -2955,7 +2943,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[request_low.request_id],
|
||||
req_id_to_index={request_low.request_id: 0},
|
||||
sampled_token_ids=[np.array([100])],
|
||||
sampled_token_ids=[[100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -3006,7 +2994,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([100]) for _ in requests],
|
||||
sampled_token_ids=[[100] for _ in requests],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -3041,7 +3029,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([100]), np.array([100, 200])],
|
||||
sampled_token_ids=[[100], [100, 200]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@@ -3227,7 +3215,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[request1.request_id, request2.request_id],
|
||||
req_id_to_index={request1.request_id: 0, request2.request_id: 1},
|
||||
sampled_token_ids=[np.array([100]), np.array([121])],
|
||||
sampled_token_ids=[[100], [121]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
|
||||
Reference in New Issue
Block a user