Fix phi4-mm and remove cuda binding (#35964)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-03-05 01:08:05 +08:00
committed by GitHub
parent e86221deb6
commit 58cfe0dc44
2 changed files with 22 additions and 19 deletions

View File

@@ -585,10 +585,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         enc_streaming_mask = self._streaming_mask(
             seq_len, batch_size, self.chunk_size, self.left_chunk
         )
-        if xs_pad.is_cuda:
-            enc_streaming_mask = enc_streaming_mask.cuda()
-            xs_pad = xs_pad.cuda()
+        device = xs_pad.device
+        enc_streaming_mask = enc_streaming_mask.to(device)
+        xs_pad = xs_pad.to(device)

         input_tensor = xs_pad
         input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
@@ -605,8 +604,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         enc_streaming_mask_nc = self._streaming_mask(
             seq_len, batch_size, chunk_size_nc, left_chunk_nc
         )
-        if xs_pad.is_cuda:
-            enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
+        if device.type != "cpu":
+            enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
         if masks is not None:
             hs_mask_nc = masks & enc_streaming_mask_nc
         else:

View File

@@ -1309,16 +1309,15 @@ class NemoConvSubsampling(torch.nn.Module):
             raise ValueError(f"Not valid sub-sampling: {subsampling}!")

         if subsampling in ["dw_striding", "striding"]:
-            in_length = torch.tensor(feat_in, dtype=torch.float)
-            out_length = calc_length(
-                lengths=in_length,
+            out_length = calc_length_int(
+                lengths=feat_in,
                 all_paddings=self._left_padding + self._right_padding,
                 kernel_size=self._kernel_size,
                 stride=self._stride,
                 ceil_mode=self._ceil_mode,
                 repeat_num=self._sampling_num,
             )
-            self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
+            self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
             self.conv2d_subsampling = True
         elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
             self.out = None
@@ -1543,22 +1542,27 @@ class NemoConvSubsampling(torch.nn.Module):
         self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor


-def calc_length(
-    lengths: Tensor,
+def calc_length_int(
+    lengths: int,
     all_paddings: int,
     kernel_size: int,
     stride: int,
     ceil_mode: bool,
     repeat_num: int = 1,
-) -> Tensor:
-    """Calculates the output length of a Tensor passed through a convolution or
-    max pooling layer"""
+) -> int:
+    """Integer-only variant of calc_length for meta-safe shape computation.
+
+    Computes the output length of a 1D convolution / pooling stack using
+    the same formula as calc_length, but operates purely on Python numbers
+    so it can be safely used during meta tensor initialization.
+    """
     add_pad: float = all_paddings - kernel_size
     one: float = 1.0
-    for i in range(repeat_num):
-        lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
-        lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
-    return lengths.to(dtype=torch.int)
+    length_f: float = float(lengths)
+    for _ in range(repeat_num):
+        length_f = (length_f + add_pad) / stride + one
+        length_f = math.ceil(length_f) if ceil_mode else math.floor(length_f)
+    return int(length_f)


 #### multihead attention starts here