Fix phi4-mm and remove cuda binding (#35964)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
@@ -585,10 +585,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
enc_streaming_mask = self._streaming_mask(
|
||||
seq_len, batch_size, self.chunk_size, self.left_chunk
|
||||
)
|
||||
|
||||
if xs_pad.is_cuda:
|
||||
enc_streaming_mask = enc_streaming_mask.cuda()
|
||||
xs_pad = xs_pad.cuda()
|
||||
device = xs_pad.device
|
||||
enc_streaming_mask = enc_streaming_mask.to(device)
|
||||
xs_pad = xs_pad.to(device)
|
||||
|
||||
input_tensor = xs_pad
|
||||
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
|
||||
@@ -605,8 +604,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
enc_streaming_mask_nc = self._streaming_mask(
|
||||
seq_len, batch_size, chunk_size_nc, left_chunk_nc
|
||||
)
|
||||
if xs_pad.is_cuda:
|
||||
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
|
||||
if device.type != "cpu":
|
||||
enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
|
||||
if masks is not None:
|
||||
hs_mask_nc = masks & enc_streaming_mask_nc
|
||||
else:
|
||||
|
||||
@@ -1309,16 +1309,15 @@ class NemoConvSubsampling(torch.nn.Module):
|
||||
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
|
||||
|
||||
if subsampling in ["dw_striding", "striding"]:
|
||||
in_length = torch.tensor(feat_in, dtype=torch.float)
|
||||
out_length = calc_length(
|
||||
lengths=in_length,
|
||||
out_length = calc_length_int(
|
||||
lengths=feat_in,
|
||||
all_paddings=self._left_padding + self._right_padding,
|
||||
kernel_size=self._kernel_size,
|
||||
stride=self._stride,
|
||||
ceil_mode=self._ceil_mode,
|
||||
repeat_num=self._sampling_num,
|
||||
)
|
||||
self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
|
||||
self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
|
||||
self.conv2d_subsampling = True
|
||||
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
|
||||
self.out = None
|
||||
@@ -1543,22 +1542,27 @@ class NemoConvSubsampling(torch.nn.Module):
|
||||
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
|
||||
|
||||
|
||||
def calc_length(
|
||||
lengths: Tensor,
|
||||
def calc_length_int(
|
||||
lengths: int,
|
||||
all_paddings: int,
|
||||
kernel_size: int,
|
||||
stride: int,
|
||||
ceil_mode: bool,
|
||||
repeat_num: int = 1,
|
||||
) -> Tensor:
|
||||
"""Calculates the output length of a Tensor passed through a convolution or
|
||||
max pooling layer"""
|
||||
) -> int:
|
||||
"""Integer-only variant of calc_length for meta-safe shape computation.
|
||||
|
||||
Computes the output length of a 1D convolution / pooling stack using
|
||||
the same formula as calc_length, but operates purely on Python numbers
|
||||
so it can be safely used during meta tensor initialization.
|
||||
"""
|
||||
add_pad: float = all_paddings - kernel_size
|
||||
one: float = 1.0
|
||||
for i in range(repeat_num):
|
||||
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
|
||||
lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
|
||||
return lengths.to(dtype=torch.int)
|
||||
length_f: float = float(lengths)
|
||||
for _ in range(repeat_num):
|
||||
length_f = (length_f + add_pad) / stride + one
|
||||
length_f = math.ceil(length_f) if ceil_mode else math.floor(length_f)
|
||||
return int(length_f)
|
||||
|
||||
|
||||
#### multihead attention starts here
|
||||
|
||||
Reference in New Issue
Block a user