Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
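This commit moves the fused MoE kernels in `model_executor/layers` off the deprecated `typing.Dict`, `typing.List`, and `typing.Tuple` aliases and onto the builtin generics `dict`, `list`, and `tuple` standardized by PEP 585 (Python 3.9+). `Optional` and `Any` are kept as `typing` imports, since the `X | None` spelling (PEP 604) would additionally require Python 3.10. A minimal sketch of the pattern applied throughout — the `lookup_config` helper below is hypothetical, not code from this diff:

```python
from typing import Any, Optional

# Before (deprecated aliases imported from typing):
#   def lookup_config(cfg: Dict[str, Any],
#                     block_shape: Optional[List[int]] = None) -> Tuple[int, ...]:

# After (PEP 585 builtin generics, the style applied throughout this commit):
def lookup_config(cfg: dict[str, Any],
                  block_shape: Optional[list[int]] = None) -> tuple[int, ...]:
    """Hypothetical helper mirroring the signatures touched below."""
    return tuple(block_shape) if block_shape is not None else (len(cfg), )
```

Migrations like this are commonly automated with `pyupgrade --py39-plus` or Ruff's `UP006`/`UP035` rules, though the commit does not say how it was produced.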
@@ -3,7 +3,7 @@
 import functools
 import json
 import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 
 import torch
 
@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool,
                             top_k: int,
-                            config: Dict[str, Any],
+                            config: dict[str, Any],
                             compute_type: tl.dtype,
                             use_fp8_w8a8: bool,
                             use_int8_w8a8: bool,
                             use_int8_w8a16: bool,
                             use_int4_w4a16: bool,
                             per_channel_quant: bool,
-                            block_shape: Optional[List[int]] = None) -> None:
+                            block_shape: Optional[list[int]] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
 def get_config_file_name(E: int,
                          N: int,
                          dtype: Optional[str],
-                         block_shape: Optional[List[int]] = None) -> str:
+                         block_shape: Optional[list[int]] = None) -> str:
     device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
     block_shape_selector = ("" if not block_shape or not all(block_shape) else
@@ -638,7 +638,7 @@ def get_moe_configs(
     dtype: Optional[str],
     block_n: Optional[int] = None,
     block_k: Optional[int] = None,
-) -> Optional[Dict[int, Any]]:
+) -> Optional[dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -670,7 +670,7 @@ def get_moe_configs(
     return None
 
 
-def get_moe_wna16_block_config(config: Dict[str,
+def get_moe_wna16_block_config(config: dict[str,
                                             int], use_moe_wna16_cuda: bool,
                                num_valid_tokens: int, size_k: int, size_n: int,
                                num_experts: int, group_size: int,
@@ -742,8 +742,8 @@ def get_default_config(
     topk: int,
     dtype: Optional[str],
     is_marlin: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Dict[str, int]:
+    block_shape: Optional[list[int]] = None,
+) -> dict[str, int]:
     if dtype == "fp8_w8a8" and block_shape is not None:
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
         # BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -795,13 +795,13 @@ def get_default_config(
 
 
 def try_get_optimal_moe_config(
-    w1_shape: Tuple[int, ...],
-    w2_shape: Tuple[int, ...],
+    w1_shape: tuple[int, ...],
+    w2_shape: tuple[int, ...],
     top_k: int,
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
-    block_shape: Optional[List[int]] = None,
+    block_shape: Optional[list[int]] = None,
 ):
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
@@ -855,7 +855,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
 
@@ -895,7 +895,7 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
     e_score_correction_bias: Optional[torch.Tensor] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
 
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           w2_zp: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
-                          block_shape: Optional[List[int]] = None) -> None:
+                          block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> None:
+        block_shape: Optional[list[int]] = None) -> None:
     pass
 
 
@@ -1046,7 +1046,7 @@ def outplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, activation, apply_router_weight_on_input,
                               use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
 
@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None,
+                  block_shape: Optional[list[int]] = None,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
     if (allow_deep_gemm and use_fp8_w8a8
             and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
     per_channel_quant: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     if use_fp8_w8a8:
         assert B_scale is not None
         if block_shape is None:
@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        w2_zp: Optional[torch.Tensor] = None,
                        a1_scale: Optional[torch.Tensor] = None,
                        a2_scale: Optional[torch.Tensor] = None,
-                       block_shape: Optional[List[int]] = None):
+                       block_shape: Optional[list[int]] = None):
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
@@ -1452,7 +1452,7 @@ def fused_moe(
     w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
+    block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1497,7 +1497,7 @@ def fused_moe(
         a1.
     - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
         a2.
-    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+    - block_shape: (Optional[list[int]]): Optional block size for block-wise
         quantization.
 
     Returns:
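A quick runtime sanity check of the new style — a sketch assuming Python 3.9+, not part of the commit: PEP 585 builtin generics are ordinary runtime objects, so annotations written this way subscript and introspect cleanly without the deprecated `typing` imports.

```python
from typing import Optional, get_type_hints

# Hypothetical function using the post-migration annotation style.
def f(block_shape: Optional[list[int]] = None) -> dict[str, int]:
    return {}

# Expected on Python 3.9+:
# {'block_shape': typing.Optional[list[int]], 'return': dict[str, int]}
print(get_type_hints(f))
```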