[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import random
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output = []
|
||||
single_step_output: List[SamplerOutput] = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs = [[] for _ in prompts]
|
||||
single_step_output_logprobs = [[] for _ in prompts]
|
||||
multi_step_output_logprobs: List[List[Dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
single_step_output_logprobs: List[List[Dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids = [[] for _ in prompts]
|
||||
single_step_output_token_ids = [[] for _ in prompts]
|
||||
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
|
||||
Reference in New Issue
Block a user