[Misc] Consolidate and optimize logic for building padded tensors (#6541)
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
|
||||
Union)
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import psutil
|
||||
import torch
|
||||
import torch.types
|
||||
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"fp8_e5m2": torch.uint8,
|
||||
}
|
||||
|
||||
# Maps a torch dtype to the equivalent numpy dtype. Used by
# make_tensor_with_pad to pick the dtype of the intermediate numpy array
# before converting it to a torch tensor via torch.from_numpy.
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}
|
||||
|
||||
# Generic helpers used throughout this module's utility functions.
P = ParamSpec('P')  # parameter specification for decorator-style wrappers
K = TypeVar("K")  # key type
T = TypeVar("T")  # value/element type
|
||||
@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
|
||||
f"(e.g., 1, 2, 3). Given input: {s}") from e
|
||||
|
||||
|
||||
def make_tensor_with_pad(
|
||||
x: List[List[int]],
|
||||
max_len: int,
|
||||
pad: int,
|
||||
dtype: torch.dtype,
|
||||
device: Optional[Union[str, torch.device]],
|
||||
) -> torch.Tensor:
|
||||
"""Make a padded tensor of a 2D inputs.
|
||||
def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Build a 2D numpy array from ragged 2D inputs.

    Each inner list occupies one row; positions past the end of a row are
    filled with ``pad``. When ``max_len`` is omitted, the longest inner
    list determines the row width.
    """
    if max_len is None:
        # map(len, ...) is measurably faster here than a genexpr over len
        max_len = max(map(len, x), default=0)

    # Start from a fully padded array, then overwrite each row's prefix.
    result = np.full((len(x), max_len), pad, dtype=dtype)
    for row_idx, row in enumerate(x):
        assert len(row) <= max_len
        result[row_idx, :len(row)] = row

    return result
|
||||
|
||||
|
||||
def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Build a padded 2D tensor from ragged 2D inputs.

    Each inner list occupies one row, padded at the end with ``pad`` up to
    ``max_len`` (or the longest row when ``max_len`` is None).
    """
    # Do the padding on the CPU in numpy first, then hand the buffer to
    # torch; torch.from_numpy shares memory rather than copying.
    numpy_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded = make_ndarray_with_pad(x, pad, numpy_dtype, max_len=max_len)

    result = torch.from_numpy(padded).to(device)
    return result.pin_memory() if pin_memory else result
|
||||
|
||||
|
||||
def async_tensor_h2d(
|
||||
|
||||
Reference in New Issue
Block a user