[Hardware][Neuron] Refactor neuron support (#3471)

This commit is contained in:
Zhuohan Li
2024-03-21 18:22:17 -07:00
committed by GitHub
parent ea5f14e6ff
commit e90fc21f2e
33 changed files with 615 additions and 549 deletions

View File

@@ -71,7 +71,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
vocab_size = 32_000
@@ -151,7 +151,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -230,7 +230,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -342,7 +342,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -486,8 +486,8 @@ def test_empty_input_batch(k: int, batch_size: int):
@torch.inference_mode()
def test_init_model():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
@@ -499,11 +499,11 @@ def test_init_model():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
draft_worker.init_model.assert_called_once()
draft_worker.init_device.assert_called_once()
target_worker.init_model.assert_called_once()
target_worker.init_device.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once()