Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
|
||||
# given a tuple of lengths for each example in the batch
|
||||
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
|
||||
# 4 examples from second eg, etc
|
||||
def get_continuous_batch(example_lens: Tuple[int, ...]):
|
||||
def get_continuous_batch(example_lens: tuple[int, ...]):
|
||||
|
||||
indices = []
|
||||
for i, x in enumerate(example_lens):
|
||||
@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: Dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
|
||||
|
||||
Reference in New Issue
Block a user