Fix phi4-mm and remove cuda binding (#35964)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-03-05 01:08:05 +08:00
committed by GitHub
parent e86221deb6
commit 58cfe0dc44
2 changed files with 22 additions and 19 deletions

View File

@@ -585,10 +585,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
enc_streaming_mask = self._streaming_mask(
seq_len, batch_size, self.chunk_size, self.left_chunk
)
if xs_pad.is_cuda:
enc_streaming_mask = enc_streaming_mask.cuda()
xs_pad = xs_pad.cuda()
device = xs_pad.device
enc_streaming_mask = enc_streaming_mask.to(device)
xs_pad = xs_pad.to(device)
input_tensor = xs_pad
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
@@ -605,8 +604,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
enc_streaming_mask_nc = self._streaming_mask(
seq_len, batch_size, chunk_size_nc, left_chunk_nc
)
if xs_pad.is_cuda:
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
if device.type != "cpu":
enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
if masks is not None:
hs_mask_nc = masks & enc_streaming_mask_nc
else: