diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py index bb9f35bdb..9d287ba9b 100644 --- a/vllm/model_executor/models/kimi_k25.py +++ b/vllm/model_executor/models/kimi_k25.py @@ -23,6 +23,10 @@ from transformers.processing_utils import ProcessorMixin from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, +) from vllm.model_executor.models.interfaces import ( SupportsMultiModal, SupportsPP, @@ -361,6 +365,7 @@ class KimiK25ForConditionalGeneration( with self._mark_tower_model(vllm_config, "vision_chunk"): self.vision_tower = MoonViT3dPretrainedModel( config.vision_config, + quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "vision_tower"), ) self.vision_tower = self.vision_tower.to( @@ -370,6 +375,7 @@ class KimiK25ForConditionalGeneration( self.mm_projector = KimiK25MultiModalProjector( config=config.vision_config, use_data_parallel=self.use_data_parallel, + quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "mm_projector"), ) self.mm_projector = self.mm_projector.to( @@ -389,6 +395,11 @@ class KimiK25ForConditionalGeneration( ) self.media_placeholder: int = self.config.media_placeholder_token_id + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, CompressedTensorsConfig): + return None + return quant_config + def _parse_and_validate_media_input( self, **kwargs: object ) -> KimiK25MediaPixelInputs | None: diff --git a/vllm/model_executor/models/kimi_k25_vit.py b/vllm/model_executor/models/kimi_k25_vit.py index 470311ecc..69524293c 100644 --- a/vllm/model_executor/models/kimi_k25_vit.py +++ b/vllm/model_executor/models/kimi_k25_vit.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import ( ReplicatedLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.vision import ( is_vit_use_data_parallel, @@ -304,6 +305,7 @@ class MLP2(nn.Module): dims: list[int], activation, bias: bool = True, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -314,6 +316,7 @@ class MLP2(nn.Module): dims[0], dims[1], bias=bias, + quant_config=quant_config, prefix=maybe_prefix(prefix, "fc0"), disable_tp=self.use_data_parallel, ) @@ -321,6 +324,7 @@ class MLP2(nn.Module): dims[1], dims[2], bias=bias, + quant_config=quant_config, prefix=maybe_prefix(prefix, "fc1"), disable_tp=self.use_data_parallel, ) @@ -341,6 +345,7 @@ class MoonViTEncoderLayer(nn.Module): num_heads: int, hidden_dim: int, mlp_dim: int, + quant_config: QuantizationConfig | None = None, prefix: str = "", *, activation=F.gelu, @@ -362,6 +367,7 @@ class MoonViTEncoderLayer(nn.Module): self.mlp = MLP2( [hidden_dim, mlp_dim, hidden_dim], activation, + quant_config=quant_config, prefix=f"{prefix}.mlp", use_data_parallel=self.use_data_parallel, ) @@ -371,6 +377,7 @@ class MoonViTEncoderLayer(nn.Module): total_num_heads=num_heads, total_num_kv_heads=num_heads, bias=attn_bias, + quant_config=quant_config, prefix=f"{prefix}.wqkv", disable_tp=self.use_data_parallel, ) @@ -378,6 +385,7 @@ class MoonViTEncoderLayer(nn.Module): hidden_dim, hidden_dim, bias=attn_bias, + quant_config=quant_config, prefix=f"{prefix}.wo", disable_tp=self.use_data_parallel, ) @@ -461,6 +469,7 @@ class MoonViT3dEncoder(nn.Module): num_layers: int, block_cfg: dict, video_attn_type: str = "spatial_temporal", + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -476,6 +485,7 @@ class MoonViT3dEncoder(nn.Module): [ MoonViTEncoderLayer( **block_cfg, + quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", ) for layer_idx in range(num_layers) @@ -544,6 +554,7 @@ class MoonViT3dPretrainedModel(nn.Module): def __init__( self, config: KimiK25VisionConfig, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -573,6 +584,7 @@ class MoonViT3dPretrainedModel(nn.Module): "attn_bias": True, }, video_attn_type=config.video_attn_type, + quant_config=quant_config, prefix=maybe_prefix(prefix, "encoder"), ) @@ -646,6 +658,7 @@ class KimiK25MultiModalProjector(nn.Module): self, config: KimiK25VisionConfig, use_data_parallel: bool = False, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -660,12 +673,14 @@ class KimiK25MultiModalProjector(nn.Module): self.hidden_size, self.hidden_size, bias=True, + quant_config=quant_config, prefix=f"{prefix}.linear_1", ) self.linear_2 = ReplicatedLinear( self.hidden_size, config.mm_hidden_size, bias=True, + quant_config=quant_config, prefix=f"{prefix}.linear_2", ) self.act = GELUActivation()