mypy type checking for vllm/worker (#11418)
Signed-off-by: lucast2021 <lucast2021@headroyce.org> Co-authored-by: lucast2021 <lucast2021@headroyce.org>
This commit is contained in:
@@ -333,9 +333,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
|||||||
def prepare_worker_input(
|
def prepare_worker_input(
|
||||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||||
assert execute_model_req is not None
|
assert execute_model_req is not None
|
||||||
virtual_engine = execute_model_req.virtual_engine
|
virtual_engine: int = execute_model_req.virtual_engine
|
||||||
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
|
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
|
||||||
blocks_to_copy = execute_model_req.blocks_to_copy
|
|
||||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int64).view(-1, 2)
|
dtype=torch.int64).view(-1, 2)
|
||||||
|
|||||||
@@ -406,8 +406,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
|||||||
if not cont:
|
if not cont:
|
||||||
break
|
break
|
||||||
|
|
||||||
def _final_process_outputs(self, model_input: StatefulModelInput,
|
def _final_process_outputs(
|
||||||
output_proc_callback: Optional[Callable]):
|
self, model_input: StatefulModelInput,
|
||||||
|
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
|
||||||
assert model_input.frozen_model_input is not None
|
assert model_input.frozen_model_input is not None
|
||||||
|
|
||||||
has_async_callback = output_proc_callback is not None
|
has_async_callback = output_proc_callback is not None
|
||||||
@@ -594,8 +595,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
|||||||
# should be [SamplerOutput]
|
# should be [SamplerOutput]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
|
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
|
||||||
num_queries):
|
num_seqs: Optional[int], num_queries: int):
|
||||||
|
|
||||||
assert sampling_metadata.num_prompts == 0
|
assert sampling_metadata.num_prompts == 0
|
||||||
assert len(sampling_metadata.seq_groups) == num_queries
|
assert len(sampling_metadata.seq_groups) == num_queries
|
||||||
@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
|
|||||||
seq_ids = seq_group.seq_ids
|
seq_ids = seq_group.seq_ids
|
||||||
next_token_ids = sample_result
|
next_token_ids = sample_result
|
||||||
parent_ids = [0]
|
parent_ids = [0]
|
||||||
|
seq_outputs: List[SequenceOutput]
|
||||||
|
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
completion_seq_group_output: CompletionSequenceGroupOutput = \
|
completion_seq_group_output: CompletionSequenceGroupOutput = \
|
||||||
cache.cached_completion_seq_group_output.get_object()
|
cache.cached_completion_seq_group_output.get_object()
|
||||||
completion_seq_group_output.samples.clear()
|
completion_seq_group_output.samples.clear()
|
||||||
seq_outputs: List[
|
seq_outputs = completion_seq_group_output.samples
|
||||||
SequenceOutput] = completion_seq_group_output.samples
|
|
||||||
else:
|
else:
|
||||||
seq_outputs = []
|
seq_outputs = []
|
||||||
|
|
||||||
|
|||||||
@@ -452,7 +452,7 @@ class WorkerWrapperBase:
|
|||||||
self.worker = worker_class(*args, **kwargs)
|
self.worker = worker_class(*args, **kwargs)
|
||||||
assert self.worker is not None
|
assert self.worker is not None
|
||||||
|
|
||||||
def execute_method(self, method, *args, **kwargs):
|
def execute_method(self, method: str, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
target = self if self.worker is None else self.worker
|
target = self if self.worker is None else self.worker
|
||||||
executor = getattr(target, method)
|
executor = getattr(target, method)
|
||||||
|
|||||||
Reference in New Issue
Block a user