[Typing] Mypy typing part 2 (#4043)

Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
This commit is contained in:
SangBin Cho
2024-04-18 09:28:43 +09:00
committed by GitHub
parent a53222544c
commit 533d2a1f39
20 changed files with 180 additions and 126 deletions

View File

@@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _expand_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids_list: List[TokenId],
proposal_token_ids_list: List[List[TokenId]],
proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding
@@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _create_target_seq_group_metadata(
self,
input_seq_group_metadata: SequenceGroupMetadata,
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
@@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
[0, 1, 2]
[0, 1, 2, 3]
"""
empty_token_ids = []
empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import torch
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> SpeculativeScores:
raise NotImplementedError

View File

@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
Returns a CUDA event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._copy_stream):

View File

@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._proposer: Optional[DraftModelTop1Proposer] = None
# Lazy initialization list.
self._proposer: DraftModelTop1Proposer
def init_device(self):
super().init_device()
@@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
self._vocab_size,
dtype=torch.float32,
device=self._device)
proposal_lens = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
@@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens
return proposal_tokens, proposal_probs, proposal_lens_tensor

View File

@@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.probs_dtype = self.rejection_sampler.probs_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype
self.scorer: SpeculativeScorer = None
# Lazy initialization.
self.scorer: SpeculativeScorer
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("get spec proposals")
# Generate proposals using draft worker.
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)