Re-enable the 80 char line width limit (#3305)

Zhuohan Li
2024-03-10 19:49:14 -07:00
committed by GitHub
parent 4b59f00e91
commit 2f8844ba08
67 changed files with 557 additions and 528 deletions
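Nearly every change below applies one of three PEP 8 wrapping patterns to bring lines under 80 characters. An illustrative, self-contained sketch of those patterns (placeholder names, not lines from this commit):

```python
# Pattern 1: a long import list wraps inside parentheses.
from os.path import (abspath, basename, dirname,
                     expanduser, join)


def three_values():
    return 1, 2, 3


# Pattern 2: a long unpacking target wraps by parenthesizing the
# left-hand side instead of using backslash continuations.
(first_value, second_value,
 third_value) = three_values()

# Pattern 3: a long right-hand side wraps by parenthesizing the
# whole expression rather than breaking inside the call.
home_config = (
    join(expanduser("~"), ".config"))

print(first_value, second_value, third_value,
      basename(dirname(abspath(home_config))))
```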

View File

@@ -5,8 +5,12 @@ import torch
 from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
 from vllm.worker.worker import Worker
-from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
-from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores
+from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
+                                   get_all_seq_ids,
+                                   split_batch_by_proposal_len)
+from vllm.spec_decode.interfaces import (SpeculativeScorer,
+                                         SpeculativeProposals,
+                                         SpeculativeScores)
 
 SeqId = int
 TargetSeqId = int
@@ -68,11 +72,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         proposal_lens_list = proposals.proposal_lens.tolist()
         proposal_token_ids_list = proposals.proposal_token_ids.tolist()
 
-        spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch(
-            seq_group_metadata_list=seq_group_metadata_list,
-            proposal_token_ids_list=proposal_token_ids_list,
-            proposal_lens_list=proposal_lens_list,
-        )
+        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
+         num_scoring_tokens) = self._expand_batch(
+             seq_group_metadata_list=seq_group_metadata_list,
+             proposal_token_ids_list=proposal_token_ids_list,
+             proposal_lens_list=proposal_lens_list,
+         )
 
         target_sampler_output = self._scorer_worker.execute_model(
             seq_group_metadata_list=target_seq_group_metadata_list,
@@ -125,7 +130,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         num_scoring_tokens = len(target_seq_group_metadata_list)
         target_seq_group_metadata_list.extend(non_spec_seqs)
 
-        return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens
+        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
+                num_scoring_tokens)
 
     def _contract_batch(self, original_bs: int,
                         target_sampler_output: List[SamplerOutput],
@@ -306,10 +312,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         # Convert non-speculative output tokens to tensors.
         sampler_output.sampled_token_probs = non_spec_probs
         sampler_output.sampled_token_ids = non_spec_sampled_tokens
-        non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
-            [sampler_output])
+        non_spec_target_token_ids, non_spec_target_probs = (
+            sampler_output_to_torch([sampler_output]))
 
-        return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs
+        return (target_token_ids, target_probs, non_spec_target_token_ids,
+                non_spec_target_probs)
 
     def _create_target_seq_id_iterator(
             self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
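The class reformatted above scores speculative tokens by batch expansion: a sequence with k proposed tokens is expanded into k+1 target-model queries, scored in one batch, then contracted back to per-sequence results. A toy sketch of the expansion step (simplified stand-ins, not vLLM's actual _expand_batch signature):

```python
from typing import List


def expand_batch(prefixes: List[List[int]],
                 proposals: List[List[int]]) -> List[List[int]]:
    """Expand each sequence into one query per proposal prefix, so
    the target model can score every position in a single pass."""
    expanded = []
    for prefix, proposal in zip(prefixes, proposals):
        # Query 0 scores the original prefix; query i scores the
        # prefix extended by the first i proposed tokens.
        for i in range(len(proposal) + 1):
            expanded.append(prefix + proposal[:i])
    return expanded


# Two sequences, each with two proposed tokens, expand into six
# target-model queries (k + 1 = 3 per sequence).
queries = expand_batch([[1, 2], [3]], [[7, 8], [9, 10]])
print(len(queries), queries)
```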

View File

@@ -5,7 +5,8 @@ import torch
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.worker import Worker
-from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
+from vllm.spec_decode.interfaces import (SpeculativeProposals,
+                                         SpeculativeProposer)
 from vllm.spec_decode.util import sampler_output_to_torch
@@ -247,8 +248,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
"""
# Split speculative- and non-speculative- sequences.
proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
(proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices) = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
@@ -306,7 +308,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
             else:
                 proposal_lens.append(0)
 
-        return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices
+        return (proposal_lens, nonzero_proposal_len_seqs,
+                nonzero_proposal_len_indices)
 
     def _merge_outputs(
             self,
@@ -356,7 +359,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
                                                 device=self._device)
             entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
 
-            proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs
+            proposal_tokens, proposal_probs = (entire_proposal_tokens,
+                                               entire_proposal_probs)
 
             proposal_lens = torch.zeros(batch_size,
                                         dtype=torch.long,
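The _split_by_max_model_len call re-wrapped above gates speculation per sequence: a sequence only receives a nonzero proposal length when its current length plus the proposal length still fits within the model's context window. A minimal sketch of that rule, assuming plain integer lengths rather than SequenceGroupMetadata objects:

```python
from typing import List, Tuple


def split_by_max_model_len(
        seq_lens: List[int], max_proposal_len: int,
        max_model_len: int) -> Tuple[List[int], List[int]]:
    """Return per-sequence proposal lengths and the indices of
    sequences that can speculate without overflowing the context."""
    proposal_lens: List[int] = []
    nonzero_indices: List[int] = []
    for i, seq_len in enumerate(seq_lens):
        if seq_len + max_proposal_len < max_model_len:
            proposal_lens.append(max_proposal_len)
            nonzero_indices.append(i)
        else:
            # Too close to the context limit: skip speculation.
            proposal_lens.append(0)
    return proposal_lens, nonzero_indices


# With a 2048-token context and k=5, the 2040-token sequence is
# excluded while the shorter one speculates.
print(split_by_max_model_len([100, 2040], 5, 2048))
```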

View File

@@ -10,7 +10,8 @@ from vllm.worker.worker import Worker
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.config import CacheConfig
-from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
+from vllm.spec_decode.util import (nvtx_range, get_all_seq_ids,
+                                   split_batch_by_proposal_len)
 from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import SpeculativeScorer
@@ -25,7 +26,7 @@ class SpecDecodeWorker:
     LLM, after which some verification routine determines which (if any) of the
     speculative tokens are accepted by the larger LLM.
 
-    See https://github.com/vllm-project/vllm/pull/2188 and 
+    See https://github.com/vllm-project/vllm/pull/2188 and
     https://github.com/vllm-project/vllm/pull/3103 for more info.
 
     The current implementation has the following limitations:
@@ -109,10 +110,12 @@ class SpecDecodeWorker:
                 block_size, gpu_memory_utilization, cpu_swap_space,
                 cache_dtype))
 
-        scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes(
-            block_size, cache_dtype)
-        proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes(
-            block_size, cache_dtype)
+        scorer_cache_block_size_bytes = (
+            self.scorer_worker.get_cache_block_size_bytes(
+                block_size, cache_dtype))
+        proposer_cache_block_size_bytes = (
+            self.proposer_worker.get_cache_block_size_bytes(
+                block_size, cache_dtype))
 
         new_num_gpu_blocks = split_num_cache_blocks_evenly(
             scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
@@ -320,8 +323,8 @@ class SpecDecodeWorker:
             sampler_output_list.append(
                 SamplerOutput(outputs=step_output_token_ids))
 
-        maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
-            k)
+        maybe_rejsample_metrics = (
+            self._metrics.maybe_collect_rejsample_metrics(k))
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics
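As the docstring above summarizes, the worker proposes tokens with a draft model, scores them with the target model in a single pass, and lets a verification routine (vLLM's RejectionSampler) accept some prefix of them. A toy outline of that accept loop with stand-in callables, not vLLM's actual implementation; the real acceptance rule is the rejection sampling described in PRs #2188 and #3103:

```python
import random
from typing import Callable, List


def speculative_step(draft_next: Callable[[List[int]], int],
                     accept_prob: Callable[[List[int], int], float],
                     prefix: List[int], k: int) -> List[int]:
    """One illustrative speculative decoding step."""
    # 1. The small draft model proposes k tokens autoregressively.
    proposal: List[int] = []
    for _ in range(k):
        proposal.append(draft_next(prefix + proposal))

    # 2./3. The target model scores all proposals in one pass;
    # verification accepts the longest passing prefix, and the
    # first rejection truncates the speculation.
    accepted: List[int] = []
    for token in proposal:
        if random.random() < accept_prob(prefix + accepted, token):
            accepted.append(token)
        else:
            break
    return accepted


# A trivial draft that always proposes token 7, accepted with
# probability 0.8 per position.
print(speculative_step(lambda ctx: 7, lambda ctx, t: 0.8, [1, 2, 3], 4))
```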