[Hardware][Neuron] Refactor neuron support (#3471)
This commit is contained in:
@@ -71,7 +71,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker.init_model()
|
||||
worker.init_device()
|
||||
|
||||
vocab_size = 32_000
|
||||
|
||||
@@ -151,7 +151,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker.init_model()
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -230,7 +230,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker.init_model()
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -342,7 +342,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker.init_model()
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -486,8 +486,8 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_init_model():
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
|
||||
def test_init_device():
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
@@ -499,11 +499,11 @@ def test_init_model():
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
|
||||
worker.init_model()
|
||||
worker.init_device()
|
||||
|
||||
draft_worker.init_model.assert_called_once()
|
||||
draft_worker.init_device.assert_called_once()
|
||||
|
||||
target_worker.init_model.assert_called_once()
|
||||
target_worker.init_device.assert_called_once()
|
||||
|
||||
metrics_collector.init_gpu_tensors.assert_called_once()
|
||||
rejection_sampler.init_gpu_tensors.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user