Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -13,7 +12,7 @@ def sampler():
|
||||
return RejectionSampler()
|
||||
|
||||
|
||||
def create_logits_tensor(token_ids: List[int],
|
||||
def create_logits_tensor(token_ids: list[int],
|
||||
vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
@@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int],
|
||||
return logits
|
||||
|
||||
|
||||
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
|
||||
batch_size = len(spec_tokens)
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor([]),
|
||||
@@ -106,7 +105,7 @@ def test_single_token_sequence(sampler):
|
||||
|
||||
def test_empty_sequence(sampler):
|
||||
"""Test handling empty sequence of speculated tokens"""
|
||||
spec_tokens: List[List[int]] = [[]]
|
||||
spec_tokens: list[list[int]] = [[]]
|
||||
output_tokens = [5] # Just the bonus token
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
|
||||
Reference in New Issue
Block a user