[V1][Spec Decode] Ngram Spec Decode (#12193)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
Lily Liu
2025-02-15 18:05:11 -08:00
committed by GitHub
parent 367cb8ce8c
commit 80f63a3966
21 changed files with 1023 additions and 82 deletions

View File

@@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
device=device),
all_greedy=False,
all_random=True,
rejection_sampling=False,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
no_top_p=all(x == 1.0 for x in top_p),
@@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
spec_token_ids=[],
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
@@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
# Generate the sampling metadata
sampling_metadata = input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy=False)
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(