[Speculative Decoding] Support draft model on different tensor-parallel size than target model (#5414)
This commit is contained in:
@@ -6,7 +6,8 @@ import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
@@ -28,9 +29,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Lazy initialization list.
|
||||
self._proposer: Top1Proposer
|
||||
self._proposer: SpeculativeProposer
|
||||
|
||||
def init_device(self):
|
||||
def init_device(self) -> None:
|
||||
super().init_device()
|
||||
|
||||
self._proposer = Top1Proposer(
|
||||
@@ -40,7 +41,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
max_proposal_len=self.max_model_len,
|
||||
)
|
||||
|
||||
def set_include_gpu_probs_tensor(self):
|
||||
def set_include_gpu_probs_tensor(self) -> None:
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self.model_runner.model.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
@@ -73,7 +74,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
# Run model sample_len times.
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
for _ in range(sample_len):
|
||||
model_output = super().execute_model(
|
||||
model_output: List[SamplerOutput] = super().execute_model(
|
||||
execute_model_req=copied_execute_model_req)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
|
||||
@@ -3,10 +3,10 @@ from typing import List, Optional, Tuple
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposer
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||
|
||||
|
||||
class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
|
||||
class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
|
||||
"""Interface for proposer workers"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
149
vllm/spec_decode/smaller_tp_proposer_worker.py
Normal file
149
vllm/spec_decode/smaller_tp_proposer_worker.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (get_tp_group,
|
||||
init_model_parallel_group,
|
||||
patch_tensor_parallel_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
"""Class which allows a speculative draft model to run with smaller tensor
|
||||
parallel degree than target model.
|
||||
This reduces the communication overhead of small draft models.
|
||||
|
||||
To implement this feature, this class differs behavior based on is_dummy
|
||||
flag, where dummy means worker that does not participate draft generation.
|
||||
Participating workers use a smaller tp group by patching vLLM's tensor
|
||||
parallel group temporarily during forward passes of draft models.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
|
||||
target_tensor_parallel_size: int):
|
||||
"""Wrap the worker in a SmallerTpProposerWorker if necessary.
|
||||
"""
|
||||
if draft_tensor_parallel_size == target_tensor_parallel_size:
|
||||
return worker
|
||||
|
||||
# gpu ranks that will generate draft tokens together
|
||||
draft_ranks = list(range(draft_tensor_parallel_size))
|
||||
|
||||
logger.info("Wrapping {%s} in {%s}", type(worker), cls)
|
||||
return cls(worker, draft_ranks)
|
||||
|
||||
def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
|
||||
"""Create a SmallerTpProposerWorker.
|
||||
|
||||
Args:
|
||||
worker (MultiStepWorker): an actual worker wrapped with this class
|
||||
draft_ranks (List[int]): if this value is given, only the GPU ranks
|
||||
written in this value participate in draft generation
|
||||
"""
|
||||
self._worker = worker
|
||||
self._draft_ranks = draft_ranks
|
||||
|
||||
# init during init_device
|
||||
self._is_dummy = False
|
||||
self._tp_group = None
|
||||
|
||||
def _patch_tensor_parallel_group(self):
|
||||
"""Temporarily patch the global tp group state with its own tp group
|
||||
state.
|
||||
"""
|
||||
return patch_tensor_parallel_group(self._tp_group)
|
||||
|
||||
def init_device(self) -> None:
|
||||
self._is_dummy = get_tp_group().rank not in self._draft_ranks
|
||||
|
||||
# dummy workers do nothing
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
# creates tp process group containing only a subset of gpu ranks
|
||||
local_rank = get_tp_group().local_rank
|
||||
tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
|
||||
self._tp_group = init_model_parallel_group([self._draft_ranks],
|
||||
local_rank, tp_backend)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.init_device()
|
||||
|
||||
def set_include_gpu_probs_tensor(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self._worker.set_include_gpu_probs_tensor()
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
if self._is_dummy:
|
||||
# this case is not used now
|
||||
return -1, -1
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
# Do not check _is_dummy, as it's always called by get_spec_proposals
|
||||
return self._worker.sampler_output(execute_model_req, sample_len)
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
if self._is_dummy:
|
||||
return SpeculativeProposals(None, None, None)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.get_spec_proposals(execute_model_req)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
if self._is_dummy:
|
||||
return []
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.execute_model(execute_model_req)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
if self._is_dummy:
|
||||
# by returning zero, target worker can use the entire kv cache space
|
||||
return 0
|
||||
|
||||
return self._worker.get_cache_block_size_bytes()
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._worker.vocab_size
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
@@ -18,6 +18,7 @@ from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
|
||||
from vllm.spec_decode.util import (create_sequence_group_output,
|
||||
get_all_num_logprobs,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
@@ -90,7 +91,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
@classmethod
|
||||
def create_worker(
|
||||
cls,
|
||||
scorer_worker: WorkerBase,
|
||||
scorer_worker: Worker,
|
||||
draft_worker_kwargs: Dict[str, Any],
|
||||
disable_by_batch_size: Optional[int],
|
||||
) -> "SpecDecodeWorker":
|
||||
@@ -111,7 +112,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
disable_bonus_tokens = False
|
||||
else:
|
||||
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
|
||||
'parallel_config']
|
||||
draft_tp = draft_parallel_config.tensor_parallel_size
|
||||
target_tp = scorer_worker.parallel_config.tensor_parallel_size
|
||||
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
|
||||
proposer_worker, draft_tp, target_tp)
|
||||
|
||||
logger.info("Configuring SpecDecodeWorker with proposer=%s",
|
||||
type(proposer_worker))
|
||||
|
||||
Reference in New Issue
Block a user