[Bugfix] Fix isinstance check for tensor types in _load_prompt_embeds to use dtype comparison (#21612)
Signed-off-by: Alexandre Juan <a.juan@netheos.net>
This commit is contained in:
@@ -957,9 +957,11 @@ class OpenAIServing:
|
||||
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
|
||||
tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
|
||||
weights_only=True)
|
||||
assert isinstance(
|
||||
tensor,
|
||||
(torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
|
||||
Reference in New Issue
Block a user