Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import torch
import torch.nn as nn
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
else:
self.forward_method = self.forward_native
def forward(self, draft_token_ids: List[List[int]],
def forward(self, draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module):
def flashinfer_sample(
self,
draft_token_ids: List[List[int]],
draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
draft_token_ids: List[List[int]],
draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput: