Fix phi4-mm and remove cuda binding (#35964)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-03-05 01:08:05 +08:00
committed by GitHub
parent e86221deb6
commit 58cfe0dc44
2 changed files with 22 additions and 19 deletions

View File

@@ -585,10 +585,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
enc_streaming_mask = self._streaming_mask(
seq_len, batch_size, self.chunk_size, self.left_chunk
)
if xs_pad.is_cuda:
enc_streaming_mask = enc_streaming_mask.cuda()
xs_pad = xs_pad.cuda()
device = xs_pad.device
enc_streaming_mask = enc_streaming_mask.to(device)
xs_pad = xs_pad.to(device)
input_tensor = xs_pad
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
@@ -605,8 +604,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
enc_streaming_mask_nc = self._streaming_mask(
seq_len, batch_size, chunk_size_nc, left_chunk_nc
)
if xs_pad.is_cuda:
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
if device.type != "cpu":
enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
if masks is not None:
hs_mask_nc = masks & enc_streaming_mask_nc
else:

View File

@@ -1309,16 +1309,15 @@ class NemoConvSubsampling(torch.nn.Module):
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
if subsampling in ["dw_striding", "striding"]:
in_length = torch.tensor(feat_in, dtype=torch.float)
out_length = calc_length(
lengths=in_length,
out_length = calc_length_int(
lengths=feat_in,
all_paddings=self._left_padding + self._right_padding,
kernel_size=self._kernel_size,
stride=self._stride,
ceil_mode=self._ceil_mode,
repeat_num=self._sampling_num,
)
self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
self.conv2d_subsampling = True
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
self.out = None
@@ -1543,22 +1542,27 @@ class NemoConvSubsampling(torch.nn.Module):
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
def calc_length_int(
    lengths: int,
    all_paddings: int,
    kernel_size: int,
    stride: int,
    ceil_mode: bool,
    repeat_num: int = 1,
) -> int:
    """Integer-only variant of calc_length for meta-safe shape computation.

    Computes the output length of a 1D convolution / pooling stack using
    the standard conv-length formula ``(L + paddings - kernel) / stride + 1``
    applied ``repeat_num`` times, operating purely on Python ints so it can
    be safely used during meta tensor initialization and stays exact even
    for lengths beyond float precision (> 2**53).

    Args:
        lengths: Input sequence length.
        all_paddings: Sum of left and right padding.
        kernel_size: Convolution/pooling kernel size.
        stride: Convolution/pooling stride (assumed positive, as in conv layers).
        ceil_mode: If True, round each intermediate length up; otherwise down.
        repeat_num: Number of stacked subsampling layers to apply.

    Returns:
        The output length after ``repeat_num`` applications of the formula.
    """
    add_pad = all_paddings - kernel_size
    length = int(lengths)
    for _ in range(repeat_num):
        numerator = length + add_pad
        if ceil_mode:
            # ceil(a / b) == -((-a) // b) for positive b; exact in ints and
            # identical to math.ceil on the float quotient when it is exact.
            length = -((-numerator) // stride) + 1
        else:
            # Python floor division matches math.floor for any sign of the
            # numerator, so this reproduces the floor branch exactly.
            length = numerator // stride + 1
    return length
#### multihead attention starts here