[Model] MLPSpeculator speculative decoding support (#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com> Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
This commit is contained in:
committed by
GitHub
parent
6c5b7af152
commit
b12518d3cf
@@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
@@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
Reference in New Issue
Block a user