[Typing] Mypy typing part 2 (#4043)
Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
This commit is contained in:
@@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
def _expand_batch(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_token_ids_list: List[TokenId],
|
||||
proposal_token_ids_list: List[List[TokenId]],
|
||||
proposal_lens_list: List[int],
|
||||
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
|
||||
"""Given the input sequences and potentially multiple corresponding
|
||||
@@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
def _create_target_seq_group_metadata(
|
||||
self,
|
||||
input_seq_group_metadata: SequenceGroupMetadata,
|
||||
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
batch_index: int,
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
@@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
[0, 1, 2]
|
||||
[0, 1, 2, 3]
|
||||
"""
|
||||
empty_token_ids = []
|
||||
empty_token_ids: List[TokenId] = []
|
||||
|
||||
token_ids_to_score = [empty_token_ids]
|
||||
token_ids_to_score.extend([
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
k: int,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> SpeculativeScores:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
|
||||
|
||||
Returns a CUDA event recording when the copy is complete.
|
||||
"""
|
||||
assert self._copy_stream is not None
|
||||
self._copy_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
with torch.cuda.stream(self._copy_stream):
|
||||
|
||||
@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._proposer: Optional[DraftModelTop1Proposer] = None
|
||||
# Lazy initialization list.
|
||||
self._proposer: DraftModelTop1Proposer
|
||||
|
||||
def init_device(self):
|
||||
super().init_device()
|
||||
@@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
self._vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=self._device)
|
||||
proposal_lens = torch.zeros(len(proposal_lens),
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
return proposal_tokens, proposal_probs, proposal_lens
|
||||
proposal_lens_tensor = torch.zeros(len(proposal_lens),
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
|
||||
@@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
proposal_tokens, proposal_probs = (entire_proposal_tokens,
|
||||
entire_proposal_probs)
|
||||
|
||||
proposal_lens = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
|
||||
proposal_lens_tensor = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
|
||||
|
||||
return proposal_tokens, proposal_probs, proposal_lens
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
@@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.probs_dtype = self.rejection_sampler.probs_dtype
|
||||
self.token_id_dtype = self.rejection_sampler.token_id_dtype
|
||||
|
||||
self.scorer: SpeculativeScorer = None
|
||||
# Lazy initiazliation.
|
||||
self.scorer: SpeculativeScorer
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
@@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
logger.info("get spec proposals")
|
||||
# Generate proposals using draft worker.
|
||||
assert blocks_to_swap_in is not None
|
||||
assert blocks_to_swap_out is not None
|
||||
assert blocks_to_copy is not None
|
||||
proposals = self.proposer_worker.get_spec_proposals(
|
||||
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
|
||||
blocks_to_copy, k)
|
||||
|
||||
Reference in New Issue
Block a user