Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)
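This replaces `typing.Optional`/`typing.Union` with the PEP 604 union operator (`X | Y`), available since Python 3.10. A minimal sketch of the equivalence, using illustrative signatures borrowed from the diff below:

```python
import torch
from typing import Optional, Union

# Old spelling: requires the typing imports above.
def old_style(
    w1_scale: Optional[torch.Tensor] = None,
    quant_dtype: Union[torch.dtype, str, None] = None,
) -> Optional[torch.Tensor]: ...

# New spelling: identical meaning, no typing import needed (Python 3.10+).
def new_style(
    w1_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.Tensor | None: ...
```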
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, Union
 
 import torch
 
@@ -27,13 +26,13 @@ def triton_moe(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant=False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
 ) -> torch.Tensor:
     quant_config = FusedMoEQuantConfig.make(
         quant_dtype,
@@ -54,13 +53,13 @@ def batched_moe(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant: bool = False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
 ) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)
 
@@ -94,13 +93,13 @@ def naive_batched_moe(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant: bool = False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
 ) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)
 
@@ -129,8 +128,8 @@ def naive_batched_moe(
 
 
 def chunk_scales(
-    scales: Optional[torch.Tensor], start: int, end: int
-) -> Optional[torch.Tensor]:
+    scales: torch.Tensor | None, start: int, end: int
+) -> torch.Tensor | None:
     if scales is not None:
         if scales.numel() == 1:
             return scales
@@ -144,10 +143,10 @@ def make_quantized_test_activations(
     m: int,
     k: int,
     in_dtype: torch.dtype,
-    quant_dtype: Optional[torch.dtype] = None,
-    block_shape: Optional[list[int]] = None,
+    quant_dtype: torch.dtype | None = None,
+    block_shape: list[int] | None = None,
     per_act_token_quant: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
     a_q = a
     a_scale = None
@@ -172,11 +171,11 @@ def make_quantized_test_activations(
 
 def moe_quantize_weights(
     w: torch.Tensor,
-    w_s: Optional[torch.Tensor],
-    quant_dtype: Union[torch.dtype, str, None],
+    w_s: torch.Tensor | None,
+    quant_dtype: torch.dtype | str | None,
     per_token_quant: bool,
-    block_shape: Optional[list[int]],
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+    block_shape: list[int] | None,
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
@@ -220,10 +219,10 @@ def make_test_weight(
     rows: int,
     cols: int,
     in_dtype: torch.dtype = torch.bfloat16,
-    quant_dtype: Union[torch.dtype, str, None] = None,
-    block_shape: Optional[list[int]] = None,
+    quant_dtype: torch.dtype | str | None = None,
+    block_shape: list[int] | None = None,
     per_out_ch_quant: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
     w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
     w_gs = None
 
@@ -262,12 +261,12 @@ def make_test_weights(
     n: int,
     k: int,
     in_dtype: torch.dtype = torch.bfloat16,
-    quant_dtype: Union[torch.dtype, str, None] = None,
-    block_shape: Optional[list[int]] = None,
+    quant_dtype: torch.dtype | str | None = None,
+    block_shape: list[int] | None = None,
     per_out_ch_quant: bool = False,
 ) -> tuple[
-    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
-    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
 ]:
     return (
         make_test_weight(
@@ -295,9 +294,9 @@ def make_test_quant_config(
     n: int,
     k: int,
     in_dtype: torch.dtype,
-    quant_dtype: Union[torch.dtype, str, None] = None,
+    quant_dtype: torch.dtype | str | None = None,
     per_act_token_quant: bool = False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
     (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
         e,
@@ -310,8 +309,8 @@ def make_test_quant_config(
     )
 
     # Hacky/trivial scales for nvfp4.
-    a1_gscale: Optional[torch.Tensor] = None
-    a2_gscale: Optional[torch.Tensor] = None
+    a1_gscale: torch.Tensor | None = None
+    a2_gscale: torch.Tensor | None = None
     if quant_dtype == "nvfp4":
         a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
         a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
@@ -348,9 +347,9 @@ def fused_moe(
     score: torch.Tensor,
     topk: int,
     renormalize: bool = False,
-    quant_config: Optional[FusedMoEQuantConfig] = None,
+    quant_config: FusedMoEQuantConfig | None = None,
     global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
     topk_weights, topk_ids, _ = fused_topk(
         hidden_states, score.float(), topk, renormalize
@@ -378,7 +377,7 @@ class BaselineMM(torch.nn.Module):
         self.b = b.to(dtype=torch.float32)
         self.out_dtype = out_dtype
 
-    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
         return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
 
 
@@ -422,8 +421,8 @@ class RealMLP(torch.nn.Module):
         quant_config=None,
         reduce_results: bool = True,
         prefix: str = "",
-        w1_s: Optional[torch.Tensor] = None,
-        w2_s: Optional[torch.Tensor] = None,
+        w1_s: torch.Tensor | None = None,
+        w2_s: torch.Tensor | None = None,
     ) -> None:
         from vllm.model_executor.layers.linear import (
             MergedColumnParallelLinear,
@@ -481,7 +480,7 @@ def make_shared_experts(
     N: int,
     K: int,
     in_dtype: torch.dtype = torch.bfloat16,
-    quant_dtype: Union[torch.dtype, str, None] = None,
+    quant_dtype: torch.dtype | str | None = None,
 ) -> torch.nn.Module:
     from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 