Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import torch
@@ -13,7 +12,7 @@ def sampler():
return RejectionSampler()
def create_logits_tensor(token_ids: List[int],
def create_logits_tensor(token_ids: list[int],
vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
@@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int],
return logits
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
batch_size = len(spec_tokens)
return SamplingMetadata(
temperature=torch.tensor([]),
@@ -106,7 +105,7 @@ def test_single_token_sequence(sampler):
def test_empty_sequence(sampler):
"""Test handling empty sequence of speculated tokens"""
spec_tokens: List[List[int]] = [[]]
spec_tokens: list[list[int]] = [[]]
output_tokens = [5] # Just the bonus token
metadata = create_sampling_metadata(spec_tokens)