[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
@@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
|
||||
device=device),
|
||||
all_greedy=False,
|
||||
all_random=True,
|
||||
rejection_sampling=False,
|
||||
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
|
||||
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
|
||||
no_top_p=all(x == 1.0 for x in top_p),
|
||||
@@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[],
|
||||
min_tokens=min_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
no_penalties=(all(x == 0 for x in presence_penalties)
|
||||
@@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
|
||||
# Generate the sampling metadata
|
||||
sampling_metadata = input_batch.make_sampling_metadata(
|
||||
req_id_output_token_ids, skip_copy=False)
|
||||
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
|
||||
|
||||
# Create expected output.
|
||||
expected_sampling_metadata = _construct_expected_sampling_metadata(
|
||||
|
||||
@@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
@@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={req_id},
|
||||
@@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={},
|
||||
@@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
scheduled_cached_reqs=[cached_req_data],
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
@@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
@@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={req_ids[0]: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
|
||||
Reference in New Issue
Block a user