Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
|
||||
else:
|
||||
self.forward_method = self.forward_native
|
||||
|
||||
def forward(self, draft_token_ids: List[List[int]],
|
||||
def forward(self, draft_token_ids: list[list[int]],
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> SamplerOutput:
|
||||
if not sampling_metadata.all_greedy:
|
||||
@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module):
|
||||
|
||||
def flashinfer_sample(
|
||||
self,
|
||||
draft_token_ids: List[List[int]],
|
||||
draft_token_ids: list[list[int]],
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module):
|
||||
# TODO: The following method can be optimized for better performance.
|
||||
def forward_native(
|
||||
self,
|
||||
draft_token_ids: List[List[int]],
|
||||
draft_token_ids: list[list[int]],
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
|
||||
Reference in New Issue
Block a user