[Bugfix] Fix qwen2.5-vl overflow issue (#13968)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py committed 2025-02-28 01:30:39 +08:00 (via GitHub)
parent 1dd422b64a
commit 7864875879
4 changed files with 22 additions and 15 deletions

@@ -47,7 +47,7 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
                        MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
                        _minicpmv_field_config)
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix

 CPU_DEVICE = torch.device("cpu")
@@ -469,13 +469,8 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
                                               training=self.training)
         hidden_states = residual + hidden_states

-        if hidden_states.dtype == torch.float16 and (
-                torch.isinf(hidden_states).any()
-                or torch.isnan(hidden_states).any()):
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states,
-                                        min=-clamp_value,
-                                        max=clamp_value)
+        if hidden_states.dtype == torch.float16:
+            hidden_states = cast_overflow_tensors(hidden_states)

         outputs = (hidden_states, )

@@ -63,7 +63,7 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
-from .utils import (AutoWeightsLoader, WeightsMapper,
+from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
 from .vision import get_vit_attn_backend
@@ -641,6 +641,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
                                 cu_seqlens=cu_seqlens_now,
                                 rotary_pos_emb=rotary_pos_emb)

+        # For Qwen2.5-VL-3B, float16 will overflow at the last block
+        # for long visual token sequences.
+        if hidden_states.dtype == torch.float16:
+            hidden_states = cast_overflow_tensors(hidden_states)
+
         # adapter
         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)
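
For reference, the float16 overflow described in the comment above is easy to reproduce outside vLLM; the snippet below is only an illustrative sketch using stock PyTorch, not code from this commit, and mirrors the clamping strategy of the new helper.

import torch

# float16 can only represent magnitudes up to ~65504, so accumulating
# activations over a long visual token sequence can push values past that
# limit and turn them into inf, which then propagates through later layers.
x = torch.tensor([60000.0, 60000.0], dtype=torch.float16)
overflowed = x[0] + x[1]                   # 120000 -> inf in float16
print(torch.finfo(torch.float16).max)      # 65504.0
print(overflowed)                          # tensor(inf, dtype=torch.float16)

# Clamping to just below the dtype maximum keeps the value finite, which is
# what the guard added after the last vision block does.
clamp_value = torch.finfo(torch.float16).max - 1000
print(torch.clamp(overflowed, min=-clamp_value, max=clamp_value))  # ~64504.0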

@@ -641,3 +641,13 @@ def extract_layer_index(layer_name: str) -> int:
     assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                 " only contain one integer")
     return int_vals[0]
+
+
+def cast_overflow_tensors(
+    tensors: torch.Tensor,
+    offset: float = 1000,
+) -> torch.Tensor:
+    if tensors.isinf().any() or tensors.isnan().any():
+        clamp_value = torch.finfo(tensors.dtype).max - offset
+        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
+    return tensors
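
A minimal usage sketch of the new helper follows; the function body is copied inline so the snippet runs without importing vLLM (inside vLLM, the other files in this commit import it with "from .utils import cast_overflow_tensors").

import torch

# Standalone copy of the helper added above, for illustration only.
def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
    return tensors

hidden = torch.tensor([1.0, float("inf"), -float("inf"), 2.0],
                      dtype=torch.float16)
# Infinities are clamped to +/-(65504 - 1000) = +/-64504; finite values pass
# through unchanged, and the tensor is returned as-is when no overflow or
# nan is detected.
print(cast_overflow_tensors(hidden))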

@@ -35,7 +35,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

 from .interfaces import SupportsMultiModal, SupportsTranscription
-from .utils import AutoWeightsLoader, WeightsMapper, make_layers
+from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
+                    make_layers)

 logger = init_logger(__name__)
@@ -285,11 +286,7 @@ class WhisperEncoderLayer(nn.Module):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

-        if hidden_states.isinf().any() or hidden_states.isnan().any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states,
-                                        min=-clamp_value,
-                                        max=clamp_value)
+        hidden_states = cast_overflow_tensors(hidden_states)

         return hidden_states