Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions

View File

@@ -3,7 +3,7 @@
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str:
block_shape: Optional[list[int]] = None) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else
@@ -638,7 +638,7 @@ def get_moe_configs(
dtype: Optional[str],
block_n: Optional[int] = None,
block_k: Optional[int] = None,
) -> Optional[Dict[int, Any]]:
) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
@@ -670,7 +670,7 @@ def get_moe_configs(
return None
def get_moe_wna16_block_config(config: Dict[str,
def get_moe_wna16_block_config(config: dict[str,
int], use_moe_wna16_cuda: bool,
num_valid_tokens: int, size_k: int, size_n: int,
num_experts: int, group_size: int,
@@ -742,8 +742,8 @@ def get_default_config(
topk: int,
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -795,13 +795,13 @@ def get_default_config(
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
w1_shape: tuple[int, ...],
w2_shape: tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
):
from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config()
@@ -855,7 +855,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@@ -895,7 +895,7 @@ def grouped_topk(
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
pass
@@ -1046,7 +1046,7 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
block_shape: Optional[list[int]] = None):
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
@@ -1452,7 +1452,7 @@ def fused_moe(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1497,7 +1497,7 @@ def fused_moe(
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
- block_shape: (Optional[list[int]]): Optional block size for block-wise
quantization.
Returns: