[Refactor] Clean up pooling serial utils (#33665)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,17 +5,19 @@ import torch
|
||||
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from vllm.utils.serial_utils import (
|
||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||
EMBED_DTYPES,
|
||||
ENDIANNESS,
|
||||
EmbedDType,
|
||||
Endianness,
|
||||
binary2tensor,
|
||||
tensor2binary,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("endianness", ENDIANNESS)
|
||||
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
|
||||
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPES.keys())
|
||||
@torch.inference_mode()
|
||||
def test_encode_and_decode(embed_dtype: str, endianness: str):
|
||||
def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness):
|
||||
for i in range(10):
|
||||
tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
|
||||
shape = tensor.shape
|
||||
|
||||
Reference in New Issue
Block a user