Update deprecated type hinting in vllm/compilation (#18072)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 16:32:48 +01:00
committed by GitHub
parent fc407a1425
commit 19324d660c
13 changed files with 70 additions and 69 deletions

View File

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
import torch._inductor.pattern_matcher as pm
@@ -125,7 +125,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -142,7 +142,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -190,7 +190,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -207,7 +207,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(