[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)
This commit is contained in:
@@ -381,6 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@@ -479,7 +480,8 @@ def test_k_equals_zero(k: int, batch_size: int,
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), False,
|
||||
metrics_collector)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
@@ -490,9 +492,10 @@ def test_k_equals_zero(k: int, batch_size: int,
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].probs is None, "expect gpu tensor references to be None"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_tokens is None, "expect gpu tensor references to be None"
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
@@ -524,7 +527,8 @@ def test_empty_input_batch(k: int, batch_size: int,
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), False,
|
||||
metrics_collector)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
@@ -535,9 +539,10 @@ def test_empty_input_batch(k: int, batch_size: int,
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].probs is None, "expect gpu tensor references to be None"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_tokens is None, "expect gpu tensor references to be None"
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
@@ -556,7 +561,7 @@ def test_init_device(acceptance_sampler_method: str):
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
False, metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
draft_worker.init_device.assert_called_once()
|
||||
@@ -707,6 +712,7 @@ def test_populate_seq_ids_with_bonus_tokens():
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
|
||||
# This set includes all sequence IDs in the batch as well as an additional
|
||||
|
||||
Reference in New Issue
Block a user