[Misc] Various simplifications and typing fixes (#5368)
This commit is contained in:
@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
|
||||
target_sampler_output = self._scorer_worker.execute_model(
|
||||
execute_model_req=execute_model_req.clone(
|
||||
seq_group_metadata_list=target_seq_group_metadata_list, ))
|
||||
seq_group_metadata_list=target_seq_group_metadata_list))
|
||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
|
||||
@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
num_scoring_tokens)
|
||||
|
||||
def _contract_batch(
|
||||
self, contracted_bs: int,
|
||||
target_sampler_output: List[SamplerOutput],
|
||||
self, contracted_bs: int, target_sampler_output: SamplerOutput,
|
||||
proposals: SpeculativeProposals, num_scoring_tokens: int,
|
||||
non_spec_indices: List[int], spec_indices: List[int],
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
|
||||
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
||||
|
||||
target_token_ids = target_token_ids.squeeze().reshape(
|
||||
spec_expanded_bs, k + 1)
|
||||
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
|
||||
self._vocab_size)
|
||||
target_logprobs = target_logprobs.squeeze().reshape(
|
||||
spec_expanded_bs, k + 1, self._vocab_size)
|
||||
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
|
||||
target_probs = target_probs.reshape(*target_token_ids.shape,
|
||||
self._vocab_size)
|
||||
target_logprobs = target_logprobs.reshape(target_probs.shape)
|
||||
|
||||
all_tokens = torch.full(size=(contracted_bs, k + 1),
|
||||
fill_value=-1,
|
||||
device=self._device,
|
||||
dtype=torch.long)
|
||||
all_probs = torch.zeros(contracted_bs,
|
||||
k + 1,
|
||||
self._vocab_size,
|
||||
device=self._device,
|
||||
dtype=torch.float32)
|
||||
all_logprobs = torch.full(size=(
|
||||
contracted_bs,
|
||||
k + 1,
|
||||
self._vocab_size,
|
||||
),
|
||||
fill_value=-float("inf"),
|
||||
device=self._device,
|
||||
dtype=torch.float32)
|
||||
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
|
||||
fill_value=-1)
|
||||
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
|
||||
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
|
||||
fill_value=-float("inf"))
|
||||
|
||||
if non_spec_indices:
|
||||
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
|
||||
"""
|
||||
assert "speculative_config" in kwargs
|
||||
speculative_config = kwargs.get("speculative_config")
|
||||
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
|
||||
assert speculative_config is not None
|
||||
|
||||
target_worker = Worker(*args, **kwargs)
|
||||
@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
logger.info("Configuring SpecDecodeWorker with proposer=%s",
|
||||
type(proposer_worker))
|
||||
|
||||
return SpecDecodeWorker(
|
||||
proposer_worker,
|
||||
scorer_worker,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
rejection_sampler=RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens, ))
|
||||
return SpecDecodeWorker(proposer_worker,
|
||||
scorer_worker,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
rejection_sampler=RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
|
||||
nonzero_proposal_len_indices,
|
||||
)
|
||||
|
||||
def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
|
||||
@staticmethod
|
||||
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
|
||||
nonzero_proposal_len_indices, transposed):
|
||||
"""Remove sequences from nonzero_proposal_len_indices and reset
|
||||
their proposal_len to 0 the draft worker does not provide a proposal
|
||||
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
|
||||
self,
|
||||
batch_size: int,
|
||||
proposal_len: int,
|
||||
maybe_sampler_output: Optional[SamplerOutput],
|
||||
maybe_sampler_output: Optional[List[SamplerOutput]],
|
||||
proposal_lens: List[int],
|
||||
nonzero_proposal_len_indices: List[int],
|
||||
sampler_transposed: bool,
|
||||
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
|
||||
if maybe_sampler_output is None:
|
||||
# If no speculative tokens, the sampler output will be None.
|
||||
# In this case we return empty proposals.
|
||||
proposal_tokens = torch.full(
|
||||
size=(
|
||||
batch_size,
|
||||
proposal_len,
|
||||
),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=self._device,
|
||||
)
|
||||
proposal_probs = torch.zeros(
|
||||
batch_size,
|
||||
proposal_len,
|
||||
self._vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=self._device,
|
||||
)
|
||||
proposal_lens_tensor = torch.zeros(len(proposal_lens),
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_tokens = torch.tensor(-1,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len)
|
||||
proposal_probs = torch.tensor(0,
|
||||
dtype=torch.float32,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len,
|
||||
self._vocab_size)
|
||||
proposal_lens_tensor = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
len(proposal_lens))
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
|
||||
# Now, reformat the output GPU tensors such that each sequence has
|
||||
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
|
||||
|
||||
entire_proposal_tokens = torch.full(
|
||||
entire_proposal_tokens = proposal_tokens.new_full(
|
||||
size=(batch_size, *proposal_tokens.shape[1:]),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=self._device,
|
||||
)
|
||||
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
|
||||
entire_proposal_probs = torch.zeros(
|
||||
entire_proposal_probs = proposal_probs.new_zeros(
|
||||
batch_size,
|
||||
*proposal_probs.shape[1:],
|
||||
dtype=torch.float32,
|
||||
device=self._device,
|
||||
)
|
||||
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
SequenceOutput)
|
||||
|
||||
SeqId = int
|
||||
|
||||
@@ -16,11 +15,7 @@ def get_all_seq_ids(
|
||||
"""Given a list of SequenceGroupMetadata, create a list of all
|
||||
sequence ids.
|
||||
"""
|
||||
return list(
|
||||
chain.from_iterable([
|
||||
seq_group_metadata.seq_data.keys()
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]))
|
||||
return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
|
||||
|
||||
|
||||
def get_all_num_logprobs(
|
||||
@@ -68,7 +63,7 @@ def create_sequence_group_output(
|
||||
seq_id: SeqId,
|
||||
topk_token_ids: List[int],
|
||||
topk_logprobs: List[float],
|
||||
) -> SequenceGroupOutput:
|
||||
) -> CompletionSequenceGroupOutput:
|
||||
"""Create a SequenceGroupOutput given the sampling results.
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user