[Feature][Spec Decode] Simplify the use of Eagle Spec Decode (#12304)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -13,15 +13,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceOutput
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import create_batch, create_sampler_output_list, mock_worker
|
||||
from .utils import (create_batch, create_sampler_output_list, create_worker,
|
||||
mock_worker)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# but first draft still counted
|
||||
assert draft_worker.get_spec_proposals.call_count == 1
|
||||
|
||||
|
||||
def test_correctly_load_weight_for_eagle():
|
||||
"""
|
||||
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
|
||||
"""
|
||||
seed = 100
|
||||
block_size = 32
|
||||
num_gpu_blocks = 8096 // block_size
|
||||
target_worker = create_worker(
|
||||
Worker,
|
||||
"JackFram/llama-68m",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
draft_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
"abhigoyal/vllm-eagle-llama-68m-random",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False)
|
||||
worker.proposer_worker.maybe_load_lm_head_weight(
|
||||
target_worker.model_runner.model.lm_head.weight.data)
|
||||
assert torch.allclose(
|
||||
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
|
||||
worker.scorer_worker.model_runner.model.lm_head.weight.data)
|
||||
|
||||
Reference in New Issue
Block a user