[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@@ -12,8 +12,17 @@ from transformers import AutoConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
BLOCK_SIZE_N: int
|
||||
BLOCK_SIZE_K: int
|
||||
GROUP_SIZE_M: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config: Dict[str, int],
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
@@ -92,7 +101,7 @@ def benchmark_config(
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies = []
|
||||
latencies: List[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
@@ -111,7 +120,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
configs = []
|
||||
configs: List[BenchmarkConfig] = []
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
@@ -175,8 +184,8 @@ class BenchmarkWorker:
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
search_space: List[BenchmarkConfig],
|
||||
) -> BenchmarkConfig:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
@@ -199,10 +208,11 @@ class BenchmarkWorker:
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def sort_config(config: Dict[str, int]) -> Dict[str, int]:
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
return {
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
@@ -214,7 +224,7 @@ def sort_config(config: Dict[str, int]) -> Dict[str, int]:
|
||||
|
||||
|
||||
def save_configs(
|
||||
configs: Dict[int, Dict[str, int]],
|
||||
configs: Dict[int, BenchmarkConfig],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
|
||||
Reference in New Issue
Block a user