[Bugfix] Fix qwen2.5-vl overflow issue (#13968)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -47,7 +47,7 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
||||
MiniCPMVMultiModalDataParser,
|
||||
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
|
||||
_minicpmv_field_config)
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
@@ -469,13 +469,8 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.dtype == torch.float16 and (
|
||||
torch.isinf(hidden_states).any()
|
||||
or torch.isnan(hidden_states).any()):
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states,
|
||||
min=-clamp_value,
|
||||
max=clamp_value)
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
|
||||
outputs = (hidden_states, )
|
||||
|
||||
|
||||
Reference in New Issue
Block a user