Fix phi4-mm and remove cuda binding (#35964)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
@@ -585,10 +585,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
enc_streaming_mask = self._streaming_mask(
|
||||
seq_len, batch_size, self.chunk_size, self.left_chunk
|
||||
)
|
||||
|
||||
if xs_pad.is_cuda:
|
||||
enc_streaming_mask = enc_streaming_mask.cuda()
|
||||
xs_pad = xs_pad.cuda()
|
||||
device = xs_pad.device
|
||||
enc_streaming_mask = enc_streaming_mask.to(device)
|
||||
xs_pad = xs_pad.to(device)
|
||||
|
||||
input_tensor = xs_pad
|
||||
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
|
||||
@@ -605,8 +604,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
enc_streaming_mask_nc = self._streaming_mask(
|
||||
seq_len, batch_size, chunk_size_nc, left_chunk_nc
|
||||
)
|
||||
if xs_pad.is_cuda:
|
||||
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
|
||||
if device.type != "cpu":
|
||||
enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
|
||||
if masks is not None:
|
||||
hs_mask_nc = masks & enc_streaming_mask_nc
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user