[Bugfix] Add quant_config in ViT of Kimi-K2.5 (#34501)
Signed-off-by: LoganJane <LoganJane73@hotmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
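Summary of the fix: the Kimi-K2.5 vision tower (`MoonViT3dPretrainedModel`), its encoder stack (`MoonViT3dEncoder`, `MoonViTEncoderLayer`, `MLP2`), and the `KimiK25MultiModalProjector` were built without a `quant_config`, so the model's quantization settings never reached the ViT's linear layers. The patch threads `quant_config` through each constructor in that chain and adds a `_maybe_ignore_quant_config` helper that deliberately passes `None` for compressed-tensors checkpoints, keeping the ViT unquantized in that case.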
@@ -23,6 +23,10 @@ from transformers.processing_utils import ProcessorMixin
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     SupportsPP,
@@ -361,6 +365,7 @@ class KimiK25ForConditionalGeneration(
         with self._mark_tower_model(vllm_config, "vision_chunk"):
             self.vision_tower = MoonViT3dPretrainedModel(
                 config.vision_config,
+                quant_config=self._maybe_ignore_quant_config(quant_config),
                 prefix=maybe_prefix(prefix, "vision_tower"),
             )
             self.vision_tower = self.vision_tower.to(
@@ -370,6 +375,7 @@ class KimiK25ForConditionalGeneration(
         self.mm_projector = KimiK25MultiModalProjector(
             config=config.vision_config,
             use_data_parallel=self.use_data_parallel,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "mm_projector"),
         )
         self.mm_projector = self.mm_projector.to(
@@ -389,6 +395,11 @@ class KimiK25ForConditionalGeneration(
             )
         self.media_placeholder: int = self.config.media_placeholder_token_id

+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        if isinstance(quant_config, CompressedTensorsConfig):
+            return None
+        return quant_config
+
     def _parse_and_validate_media_input(
         self, **kwargs: object
     ) -> KimiK25MediaPixelInputs | None:
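A minimal self-contained sketch of what the new helper does. The two config classes below are stand-ins for vLLM's real `QuantizationConfig`/`CompressedTensorsConfig`, and the rationale in the comment (compressed-tensors checkpoints commonly quantize only the language model) is an assumption, not stated in the diff:

```python
class QuantizationConfig:  # stand-in for vLLM's base class
    pass

class CompressedTensorsConfig(QuantizationConfig):  # stand-in subclass
    pass

def _maybe_ignore_quant_config(quant_config):
    # Compressed-tensors configs get None back, so the ViT and the
    # multimodal projector are constructed as plain unquantized modules;
    # any other quantization config is forwarded unchanged.
    if isinstance(quant_config, CompressedTensorsConfig):
        return None
    return quant_config

assert _maybe_ignore_quant_config(CompressedTensorsConfig()) is None
assert _maybe_ignore_quant_config(QuantizationConfig()) is not None
```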
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.models.vision import (
     is_vit_use_data_parallel,
@@ -304,6 +305,7 @@ class MLP2(nn.Module):
         dims: list[int],
         activation,
         bias: bool = True,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ):
@@ -314,6 +316,7 @@ class MLP2(nn.Module):
             dims[0],
             dims[1],
             bias=bias,
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "fc0"),
             disable_tp=self.use_data_parallel,
         )
@@ -321,6 +324,7 @@ class MLP2(nn.Module):
             dims[1],
             dims[2],
             bias=bias,
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "fc1"),
             disable_tp=self.use_data_parallel,
         )
@@ -341,6 +345,7 @@ class MoonViTEncoderLayer(nn.Module):
         num_heads: int,
         hidden_dim: int,
         mlp_dim: int,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         *,
         activation=F.gelu,
@@ -362,6 +367,7 @@ class MoonViTEncoderLayer(nn.Module):
         self.mlp = MLP2(
             [hidden_dim, mlp_dim, hidden_dim],
             activation,
+            quant_config=quant_config,
             prefix=f"{prefix}.mlp",
             use_data_parallel=self.use_data_parallel,
         )
@@ -371,6 +377,7 @@ class MoonViTEncoderLayer(nn.Module):
             total_num_heads=num_heads,
             total_num_kv_heads=num_heads,
             bias=attn_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.wqkv",
             disable_tp=self.use_data_parallel,
         )
@@ -378,6 +385,7 @@ class MoonViTEncoderLayer(nn.Module):
             hidden_dim,
             hidden_dim,
             bias=attn_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.wo",
             disable_tp=self.use_data_parallel,
         )
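The remaining hunks all repeat one pattern: each module accepts `quant_config` and forwards it to every linear it creates. A toy, torch-only sketch of that fan-out follows; the `make_linear` factory is invented for illustration, whereas vLLM's real linear classes use the config to select a quantized weight format and kernel rather than just storing it:

```python
import torch.nn as nn

def make_linear(in_features, out_features, quant_config=None, prefix=""):
    # Stand-in for vLLM's parallel linear layers: a real quant_config
    # would pick a quantized kernel; here we only record that it arrived.
    layer = nn.Linear(in_features, out_features)
    layer.quant_config = quant_config
    return layer

class EncoderLayerSketch(nn.Module):
    def __init__(self, hidden_dim, mlp_dim, quant_config=None, prefix=""):
        super().__init__()
        # One config fans out to the QKV projection, the output
        # projection, and both MLP layers, mirroring the hunks above.
        self.wqkv = make_linear(hidden_dim, 3 * hidden_dim, quant_config, f"{prefix}.wqkv")
        self.wo = make_linear(hidden_dim, hidden_dim, quant_config, f"{prefix}.wo")
        self.fc0 = make_linear(hidden_dim, mlp_dim, quant_config, f"{prefix}.mlp.fc0")
        self.fc1 = make_linear(mlp_dim, hidden_dim, quant_config, f"{prefix}.mlp.fc1")
```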
@@ -461,6 +469,7 @@ class MoonViT3dEncoder(nn.Module):
         num_layers: int,
         block_cfg: dict,
         video_attn_type: str = "spatial_temporal",
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -476,6 +485,7 @@ class MoonViT3dEncoder(nn.Module):
             [
                 MoonViTEncoderLayer(
                     **block_cfg,
+                    quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                 )
                 for layer_idx in range(num_layers)
@@ -544,6 +554,7 @@ class MoonViT3dPretrainedModel(nn.Module):
     def __init__(
         self,
         config: KimiK25VisionConfig,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -573,6 +584,7 @@ class MoonViT3dPretrainedModel(nn.Module):
                 "attn_bias": True,
             },
             video_attn_type=config.video_attn_type,
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "encoder"),
         )

@@ -646,6 +658,7 @@ class KimiK25MultiModalProjector(nn.Module):
         self,
         config: KimiK25VisionConfig,
         use_data_parallel: bool = False,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -660,12 +673,14 @@ class KimiK25MultiModalProjector(nn.Module):
             self.hidden_size,
             self.hidden_size,
             bias=True,
+            quant_config=quant_config,
             prefix=f"{prefix}.linear_1",
         )
         self.linear_2 = ReplicatedLinear(
             self.hidden_size,
             config.mm_hidden_size,
             bias=True,
+            quant_config=quant_config,
             prefix=f"{prefix}.linear_2",
         )
         self.act = GELUActivation()
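With the config threaded through, a quantized Kimi-K2.5 checkpoint should now build its ViT consistently with the rest of the model. A hypothetical usage snippet; the model path and quantization method below are placeholders, not taken from this PR:

```python
from vllm import LLM

# Hypothetical: load a quantized Kimi-K2.5 checkpoint. The ViT's linear
# layers now see the same quantization config as the language model,
# except under compressed-tensors, where they fall back to full precision.
llm = LLM(model="path/to/kimi-k2.5-quantized", quantization="fp8")
```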