[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)

This commit is contained in:
sroy745
2024-07-20 23:58:58 -07:00
committed by GitHub
parent d7f4178dd9
commit 14f91fe67c
8 changed files with 333 additions and 64 deletions

View File

@@ -381,6 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
@@ -479,7 +480,8 @@ def test_k_equals_zero(k: int, batch_size: int,
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@@ -490,9 +492,10 @@ def test_k_equals_zero(k: int, batch_size: int,
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@@ -524,7 +527,8 @@ def test_empty_input_batch(k: int, batch_size: int,
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@@ -535,9 +539,10 @@ def test_empty_input_batch(k: int, batch_size: int,
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@@ -556,7 +561,7 @@ def test_init_device(acceptance_sampler_method: str):
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector)
False, metrics_collector)
worker.init_device()
draft_worker.init_device.assert_called_once()
@@ -707,6 +712,7 @@ def test_populate_seq_ids_with_bonus_tokens():
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional