Fix phi4-mm and remove cuda binding (#35964)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
@@ -585,10 +585,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
|||||||
enc_streaming_mask = self._streaming_mask(
|
enc_streaming_mask = self._streaming_mask(
|
||||||
seq_len, batch_size, self.chunk_size, self.left_chunk
|
seq_len, batch_size, self.chunk_size, self.left_chunk
|
||||||
)
|
)
|
||||||
|
device = xs_pad.device
|
||||||
if xs_pad.is_cuda:
|
enc_streaming_mask = enc_streaming_mask.to(device)
|
||||||
enc_streaming_mask = enc_streaming_mask.cuda()
|
xs_pad = xs_pad.to(device)
|
||||||
xs_pad = xs_pad.cuda()
|
|
||||||
|
|
||||||
input_tensor = xs_pad
|
input_tensor = xs_pad
|
||||||
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
|
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
|
||||||
@@ -605,8 +604,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
|||||||
enc_streaming_mask_nc = self._streaming_mask(
|
enc_streaming_mask_nc = self._streaming_mask(
|
||||||
seq_len, batch_size, chunk_size_nc, left_chunk_nc
|
seq_len, batch_size, chunk_size_nc, left_chunk_nc
|
||||||
)
|
)
|
||||||
if xs_pad.is_cuda:
|
if device.type != "cpu":
|
||||||
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
|
enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
hs_mask_nc = masks & enc_streaming_mask_nc
|
hs_mask_nc = masks & enc_streaming_mask_nc
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1309,16 +1309,15 @@ class NemoConvSubsampling(torch.nn.Module):
|
|||||||
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
|
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
|
||||||
|
|
||||||
if subsampling in ["dw_striding", "striding"]:
|
if subsampling in ["dw_striding", "striding"]:
|
||||||
in_length = torch.tensor(feat_in, dtype=torch.float)
|
out_length = calc_length_int(
|
||||||
out_length = calc_length(
|
lengths=feat_in,
|
||||||
lengths=in_length,
|
|
||||||
all_paddings=self._left_padding + self._right_padding,
|
all_paddings=self._left_padding + self._right_padding,
|
||||||
kernel_size=self._kernel_size,
|
kernel_size=self._kernel_size,
|
||||||
stride=self._stride,
|
stride=self._stride,
|
||||||
ceil_mode=self._ceil_mode,
|
ceil_mode=self._ceil_mode,
|
||||||
repeat_num=self._sampling_num,
|
repeat_num=self._sampling_num,
|
||||||
)
|
)
|
||||||
self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
|
self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
|
||||||
self.conv2d_subsampling = True
|
self.conv2d_subsampling = True
|
||||||
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
|
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
|
||||||
self.out = None
|
self.out = None
|
||||||
@@ -1543,22 +1542,27 @@ class NemoConvSubsampling(torch.nn.Module):
|
|||||||
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
|
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
|
||||||
|
|
||||||
|
|
||||||
def calc_length(
|
def calc_length_int(
|
||||||
lengths: Tensor,
|
lengths: int,
|
||||||
all_paddings: int,
|
all_paddings: int,
|
||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
ceil_mode: bool,
|
ceil_mode: bool,
|
||||||
repeat_num: int = 1,
|
repeat_num: int = 1,
|
||||||
) -> Tensor:
|
) -> int:
|
||||||
"""Calculates the output length of a Tensor passed through a convolution or
|
"""Integer-only variant of calc_length for meta-safe shape computation.
|
||||||
max pooling layer"""
|
|
||||||
|
Computes the output length of a 1D convolution / pooling stack using
|
||||||
|
the same formula as calc_length, but operates purely on Python numbers
|
||||||
|
so it can be safely used during meta tensor initialization.
|
||||||
|
"""
|
||||||
add_pad: float = all_paddings - kernel_size
|
add_pad: float = all_paddings - kernel_size
|
||||||
one: float = 1.0
|
one: float = 1.0
|
||||||
for i in range(repeat_num):
|
length_f: float = float(lengths)
|
||||||
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
|
for _ in range(repeat_num):
|
||||||
lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
|
length_f = (length_f + add_pad) / stride + one
|
||||||
return lengths.to(dtype=torch.int)
|
length_f = math.ceil(length_f) if ceil_mode else math.floor(length_f)
|
||||||
|
return int(length_f)
|
||||||
|
|
||||||
|
|
||||||
#### multihead attention starts here
|
#### multihead attention starts here
|
||||||
|
|||||||
Reference in New Issue
Block a user