Update Optional[x] to x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -35,7 +35,7 @@ from .mk_objects import (
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
|
||||
|
||||
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
|
||||
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
|
||||
if t is None:
|
||||
return f"{name} : None"
|
||||
else:
|
||||
@@ -44,21 +44,21 @@ def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
Ms: Union[list[int], int]
|
||||
Ms: list[int] | int
|
||||
K: int
|
||||
N: int
|
||||
E: int
|
||||
topks: Union[list[int], int]
|
||||
topks: list[int] | int
|
||||
dtype: torch.dtype
|
||||
quant_config: Optional[TestMoEQuantConfig]
|
||||
quant_config: TestMoEQuantConfig | None
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
fused_moe_chunk_size: Optional[int]
|
||||
fused_moe_chunk_size: int | None
|
||||
world_size: int
|
||||
|
||||
torch_trace_dir_path: Optional[str] = None
|
||||
torch_trace_dir_path: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.quant_config is None:
|
||||
@@ -93,7 +93,7 @@ class Config:
|
||||
return self.Ms
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||
def quant_dtype(self) -> torch.dtype | str | None:
|
||||
assert self.quant_config is not None
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@@ -112,7 +112,7 @@ class Config:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def quant_block_shape(self) -> Optional[list[int]]:
|
||||
def quant_block_shape(self) -> list[int] | None:
|
||||
assert self.quant_config is not None
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@@ -209,7 +209,7 @@ class Config:
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.backend
|
||||
|
||||
def is_valid(self) -> tuple[bool, Optional[str]]:
|
||||
def is_valid(self) -> tuple[bool, str | None]:
|
||||
# Check prepare-finalize and fused-experts compatibility
|
||||
if self.is_batched_prepare_finalize():
|
||||
if not self.is_batched_fused_experts():
|
||||
@@ -280,10 +280,10 @@ class Config:
|
||||
class WeightTensors:
|
||||
w1: torch.Tensor
|
||||
w2: torch.Tensor
|
||||
w1_scale: Optional[torch.Tensor]
|
||||
w2_scale: Optional[torch.Tensor]
|
||||
w1_gs: Optional[torch.Tensor] = None
|
||||
w2_gs: Optional[torch.Tensor] = None
|
||||
w1_scale: torch.Tensor | None
|
||||
w2_scale: torch.Tensor | None
|
||||
w1_gs: torch.Tensor | None = None
|
||||
w2_gs: torch.Tensor | None = None
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
@@ -351,11 +351,11 @@ class WeightTensors:
|
||||
@dataclass
|
||||
class RankTensors:
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: Optional[torch.Tensor]
|
||||
hidden_states_scale: torch.Tensor | None
|
||||
|
||||
topk_weights: torch.Tensor
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor]
|
||||
expert_map: torch.Tensor | None
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
@@ -370,7 +370,7 @@ class RankTensors:
|
||||
@staticmethod
|
||||
def make_hidden_states(
|
||||
config: Config,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""
|
||||
Return hidden_states
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import copy
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
@@ -82,7 +81,7 @@ def make_feature_matrix(csv_file_path: str):
|
||||
import pandas as pd
|
||||
|
||||
def add_to_results(
|
||||
config: Config, success: Result, results_df: Optional[pd.DataFrame] = None
|
||||
config: Config, success: Result, results_df: pd.DataFrame | None = None
|
||||
):
|
||||
config_dict = asdict(config)
|
||||
config_dict["prepare_finalize_type"] = config_dict[
|
||||
@@ -121,7 +120,7 @@ def make_feature_matrix(csv_file_path: str):
|
||||
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
|
||||
)
|
||||
|
||||
results_df: Optional[pd.DataFrame] = None
|
||||
results_df: pd.DataFrame | None = None
|
||||
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
|
||||
combinations
|
||||
):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,25 +42,25 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
@dataclass
|
||||
class TestMoEQuantConfig:
|
||||
quant_dtype: Union[torch.dtype, str, None]
|
||||
quant_dtype: torch.dtype | str | None
|
||||
per_out_ch_quant: bool
|
||||
per_act_token_quant: bool
|
||||
block_shape: Optional[list[int]]
|
||||
block_shape: list[int] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrepareFinalizeInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[Union[torch.dtype, str]]
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
backend: Optional[str]
|
||||
backend: str | None
|
||||
supports_apply_weight_on_input: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[Union[torch.dtype, str]]
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
supports_chunking: bool
|
||||
supports_expert_map: bool
|
||||
@@ -78,7 +77,7 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
|
||||
|
||||
standard_format = mk.FusedMoEActivationFormat.Standard
|
||||
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
|
||||
common_float_types: list[Union[torch.dtype, str]] = [
|
||||
common_float_types: list[torch.dtype | str] = [
|
||||
torch.float8_e4m3fn,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
@@ -92,9 +91,9 @@ fp8_types = [torch.float8_e4m3fn]
|
||||
def register_prepare_and_finalize(
|
||||
kind,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[Union[torch.dtype, str]],
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
backend: Optional[str],
|
||||
backend: str | None,
|
||||
force_multigpu: bool = False,
|
||||
supports_apply_weight_on_input: bool = True,
|
||||
):
|
||||
@@ -121,7 +120,7 @@ def register_prepare_and_finalize(
|
||||
def register_experts(
|
||||
kind,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[Union[torch.dtype, str]],
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
supports_chunking: bool,
|
||||
supports_expert_map: bool,
|
||||
@@ -340,7 +339,7 @@ if cutlass_fp4_supported():
|
||||
supports_expert_map=False,
|
||||
)
|
||||
|
||||
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
|
||||
MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
|
||||
None,
|
||||
# per-channel / per-column weights and per-tensor activations
|
||||
TestMoEQuantConfig(
|
||||
@@ -395,7 +394,7 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
||||
|
||||
def make_prepare_finalize(
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
backend: Optional[str],
|
||||
backend: str | None,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Concatenate
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
@@ -58,9 +59,9 @@ def _worker_parallel_launch(
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
|
||||
vllm_config: Optional[VllmConfig],
|
||||
env_dict: Optional[dict],
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
|
||||
vllm_config: VllmConfig | None,
|
||||
env_dict: dict | None,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
from itertools import product
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user