[Bugfix] Fix qwen2.5-vl overflow issue (#13968)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -47,7 +47,7 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
||||
MiniCPMVMultiModalDataParser,
|
||||
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
|
||||
_minicpmv_field_config)
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
@@ -469,13 +469,8 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.dtype == torch.float16 and (
|
||||
torch.isinf(hidden_states).any()
|
||||
or torch.isnan(hidden_states).any()):
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states,
|
||||
min=-clamp_value,
|
||||
max=clamp_value)
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
|
||||
outputs = (hidden_states, )
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
|
||||
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
|
||||
apply_rotary_pos_emb_vision)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import get_vit_attn_backend
|
||||
@@ -641,6 +641,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb)
|
||||
|
||||
# For Qwen2.5-VL-3B, float16 will overflow at last block
|
||||
# for long visual tokens sequences.
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
|
||||
# adapter
|
||||
hidden_states = self.merger(hidden_states)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
|
||||
@@ -641,3 +641,13 @@ def extract_layer_index(layer_name: str) -> int:
|
||||
assert len(int_vals) == 1, (f"layer name {layer_name} should"
|
||||
" only contain one integer")
|
||||
return int_vals[0]
|
||||
|
||||
|
||||
def cast_overflow_tensors(
|
||||
tensors: torch.Tensor,
|
||||
offset: float = 1000,
|
||||
) -> torch.Tensor:
|
||||
if tensors.isinf().any() or tensors.isnan().any():
|
||||
clamp_value = torch.finfo(tensors.dtype).max - offset
|
||||
tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
|
||||
return tensors
|
||||
|
||||
@@ -35,7 +35,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsTranscription
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
||||
make_layers)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -285,11 +286,7 @@ class WhisperEncoderLayer(nn.Module):
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.isinf().any() or hidden_states.isnan().any():
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states,
|
||||
min=-clamp_value,
|
||||
max=clamp_value)
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user