[Misc] Consolidate and optimize logic for building padded tensors (#6541)

This commit is contained in:
Cyrus Leung
2024-07-20 12:17:24 +08:00
committed by GitHub
parent 3f8d42c81f
commit 9042d68362
9 changed files with 77 additions and 75 deletions

View File

@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Union)
import numpy as np
import numpy.typing as npt
import psutil
import torch
import torch.types
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.uint8,
}
# Maps a torch dtype to the numpy dtype used when building padded arrays on
# the CPU before converting them to torch tensors (see make_tensor_with_pad).
# Only dtypes that appear here can be used with make_tensor_with_pad.
TORCH_DTYPE_TO_NUMPY_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int32: np.int32,
torch.int64: np.int64,
}
# Generic type variables used by the helpers below.
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f"(e.g., 1, 2, 3). Given input: {s}") from e
def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Make a padded array from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`.

    Args:
        x: The 2D input (a list of rows of possibly different lengths).
        pad: The value used to fill positions past the end of each row.
        dtype: The numpy dtype of the resulting array.
        max_len: The width of the resulting array. If ``None``, it is
            inferred as the length of the longest row (0 for empty input).

    Returns:
        A ``(len(x), max_len)`` array where row ``i`` starts with ``x[i]``
        followed by ``pad`` values.
    """
    if max_len is None:
        # Unlike for most functions, map is faster than a genexpr over `len`
        max_len = max(map(len, x), default=0)

    # np.full allocates and fills in a single pass (vs. zeros + broadcast add).
    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
    for ind, blocktb in enumerate(x):
        assert len(blocktb) <= max_len
        padded_x[ind, :len(blocktb)] = blocktb

    return padded_x
def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a padded tensor from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`. The rows are first padded into a numpy array on the CPU,
    then converted to a tensor and moved to `device` (pinning the host
    memory first if `pin_memory` is set).
    """
    # Build the padded data in numpy, using the numpy dtype matching `dtype`.
    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)

    result = torch.from_numpy(padded).to(device)
    return result.pin_memory() if pin_memory else result
def async_tensor_h2d(