Re-enable the 80 char line width limit (#3305)
This commit is contained in:
@@ -5,8 +5,12 @@ import torch
|
||||
|
||||
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
|
||||
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores
|
||||
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
|
||||
get_all_seq_ids,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeScorer,
|
||||
SpeculativeProposals,
|
||||
SpeculativeScores)
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
@@ -68,11 +72,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
|
||||
|
||||
spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
proposal_token_ids_list=proposal_token_ids_list,
|
||||
proposal_lens_list=proposal_lens_list,
|
||||
)
|
||||
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens) = self._expand_batch(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
proposal_token_ids_list=proposal_token_ids_list,
|
||||
proposal_lens_list=proposal_lens_list,
|
||||
)
|
||||
|
||||
target_sampler_output = self._scorer_worker.execute_model(
|
||||
seq_group_metadata_list=target_seq_group_metadata_list,
|
||||
@@ -125,7 +130,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
num_scoring_tokens = len(target_seq_group_metadata_list)
|
||||
target_seq_group_metadata_list.extend(non_spec_seqs)
|
||||
|
||||
return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens
|
||||
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens)
|
||||
|
||||
def _contract_batch(self, original_bs: int,
|
||||
target_sampler_output: List[SamplerOutput],
|
||||
@@ -306,10 +312,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
# Convert non-speculative output tokens to tensors.
|
||||
sampler_output.sampled_token_probs = non_spec_probs
|
||||
sampler_output.sampled_token_ids = non_spec_sampled_tokens
|
||||
non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
|
||||
[sampler_output])
|
||||
non_spec_target_token_ids, non_spec_target_probs = (
|
||||
sampler_output_to_torch([sampler_output]))
|
||||
|
||||
return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs
|
||||
return (target_token_ids, target_probs, non_spec_target_token_ids,
|
||||
non_spec_target_probs)
|
||||
|
||||
def _create_target_seq_id_iterator(
|
||||
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
||||
|
||||
@@ -5,7 +5,8 @@ import torch
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
|
||||
|
||||
@@ -247,8 +248,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
"""
|
||||
|
||||
# Split speculative- and non-speculative- sequences.
|
||||
proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len(
|
||||
seq_group_metadata_list, max_proposal_len)
|
||||
(proposal_lens, nonzero_proposal_len_seqs,
|
||||
nonzero_proposal_len_indices) = self._split_by_max_model_len(
|
||||
seq_group_metadata_list, max_proposal_len)
|
||||
|
||||
if nonzero_proposal_len_seqs:
|
||||
# Speculate tokens using the draft worker for the speculative
|
||||
@@ -306,7 +308,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
else:
|
||||
proposal_lens.append(0)
|
||||
|
||||
return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices
|
||||
return (proposal_lens, nonzero_proposal_len_seqs,
|
||||
nonzero_proposal_len_indices)
|
||||
|
||||
def _merge_outputs(
|
||||
self,
|
||||
@@ -356,7 +359,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
device=self._device)
|
||||
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
||||
|
||||
proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs
|
||||
proposal_tokens, proposal_probs = (entire_proposal_tokens,
|
||||
entire_proposal_probs)
|
||||
|
||||
proposal_lens = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
|
||||
@@ -10,7 +10,8 @@ from vllm.worker.worker import Worker
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
|
||||
from vllm.spec_decode.util import (nvtx_range, get_all_seq_ids,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import SpeculativeScorer
|
||||
@@ -25,7 +26,7 @@ class SpecDecodeWorker:
|
||||
LLM, after which some verification routine determines which (if any) of the
|
||||
speculative tokens are accepted by the larger LLM.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/2188 and
|
||||
See https://github.com/vllm-project/vllm/pull/2188 and
|
||||
https://github.com/vllm-project/vllm/pull/3103 for more info.
|
||||
|
||||
The current implementation has the following limitations:
|
||||
@@ -109,10 +110,12 @@ class SpecDecodeWorker:
|
||||
block_size, gpu_memory_utilization, cpu_swap_space,
|
||||
cache_dtype))
|
||||
|
||||
scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes(
|
||||
block_size, cache_dtype)
|
||||
proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes(
|
||||
block_size, cache_dtype)
|
||||
scorer_cache_block_size_bytes = (
|
||||
self.scorer_worker.get_cache_block_size_bytes(
|
||||
block_size, cache_dtype))
|
||||
proposer_cache_block_size_bytes = (
|
||||
self.proposer_worker.get_cache_block_size_bytes(
|
||||
block_size, cache_dtype))
|
||||
|
||||
new_num_gpu_blocks = split_num_cache_blocks_evenly(
|
||||
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
|
||||
@@ -320,8 +323,8 @@ class SpecDecodeWorker:
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
|
||||
maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
|
||||
k)
|
||||
maybe_rejsample_metrics = (
|
||||
self._metrics.maybe_collect_rejsample_metrics(k))
|
||||
if maybe_rejsample_metrics is not None:
|
||||
sampler_output_list[
|
||||
0].spec_decode_worker_metrics = maybe_rejsample_metrics
|
||||
|
||||
Reference in New Issue
Block a user