diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 5094e3fbb..abb290d25 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -225,19 +225,6 @@ class MistralConfigParser(ConfigParserBase):
 
         config = adapt_config_dict(config_dict, defaults=hf_config_dict)
 
-        # Mistral configs may define sliding_window as list[int]. Convert it
-        # to int and add the layer_types list[str] to make it HF compatible
-        if (sliding_window := getattr(config, "sliding_window", None)) and isinstance(
-            sliding_window, list
-        ):
-            pattern_repeats = config.num_hidden_layers // len(sliding_window)
-            layer_types = sliding_window * pattern_repeats
-            config.layer_types = [
-                "full_attention" if layer_type is None else "sliding_attention"
-                for layer_type in layer_types
-            ]
-            config.sliding_window = next(filter(None, sliding_window), None)
-
         return config_dict, config
 
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
index d81042aa9..6a9985583 100644
--- a/vllm/transformers_utils/configs/mistral.py
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -14,6 +14,7 @@ def adapt_config_dict(
     defaults: dict[str, Any],
 ) -> PretrainedConfig:
     config_dict = _remap_general_mistral_args(config_dict)
+    config_dict = _remap_mistral_sliding_window(config_dict)
 
     if bool(config_dict.get("quantization")):
         config_dict = _remap_mistral_quantization_args(config_dict)
@@ -161,6 +162,29 @@ def _remap_general_mistral_args(config: dict) -> dict:
     return config
 
 
+def _remap_mistral_sliding_window(config: dict) -> dict:
+    # Remap sliding_window (list) -> layer_types (list) + sliding_window (int)
+    # for HF compatibility: Mistral configs may define sliding_window as a
+    # list of per-layer window sizes. Convert it to a single int and add the
+    # layer_types list[str] expected by HF configs.
+    if sliding_window := config.get("sliding_window"):
+        if isinstance(sliding_window, list):
+            pattern_repeats = config["num_hidden_layers"] // len(sliding_window)
+            layer_types = sliding_window * pattern_repeats
+            config["layer_types"] = [
+                "full_attention" if layer_type is None else "sliding_attention"
+                for layer_type in layer_types
+            ]
+            assert len(set(sliding_window) - {None}) <= 1, sliding_window
+            config["sliding_window"] = next(filter(None, sliding_window), None)
+        elif isinstance(sliding_window, int) and config.get("layer_types") is None:
+            config["layer_types"] = ["sliding_attention"] * config["num_hidden_layers"]
+        else:
+            raise ValueError(f"Unsupported sliding_window type: {sliding_window}")
+
+    return config
+
+
 def _remap_mistral_quantization_args(config: dict) -> dict:
     if config.get("quantization"):
         quantization = config.pop("quantization", {})
@@ -195,14 +219,6 @@ def _remap_mistral_audio_args(config: dict) -> dict:
     else:
         block_pool_size = 1
 
-    _maybe_sliding_window = encoder_args.get("ragged_attention", None)
-    if _maybe_sliding_window is None:
-        sliding_window = None
-    elif _maybe_sliding_window.isdigit():
-        sliding_window = int(_maybe_sliding_window)
-    else:
-        raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
-
     architecture = (
         "VoxtralRealtimeGeneration"
         if encoder_args.get("causal")
@@ -229,7 +245,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
         max_source_positions=encoder_args["max_source_positions"],
         is_encoder_decoder=False,  # Override WhisperConfig default
         is_causal=encoder_args.get("causal", False),
-        sliding_window=sliding_window,
+        sliding_window=encoder_args.get("sliding_window", None),
         block_pool_size=block_pool_size,
         pos_embed=encoder_args.get("pos_embed", "sinusoidal"),  # only needed for RoPE
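
For reviewers, here is a minimal standalone sketch of the list-form remap added in `_remap_mistral_sliding_window`. The sample `num_hidden_layers` and `sliding_window` values are made up for illustration; the transformation itself follows the patched code above.

```python
# Hypothetical Mistral-format config: an alternating sliding/full pattern
# given as a list, where None marks full-attention layers.
config = {
    "num_hidden_layers": 4,
    "sliding_window": [4096, None],
}

sliding_window = config["sliding_window"]

# Repeat the per-layer pattern until it covers every layer.
pattern_repeats = config["num_hidden_layers"] // len(sliding_window)
layer_types = sliding_window * pattern_repeats

# HF-style layer_types: None -> "full_attention", int -> "sliding_attention".
config["layer_types"] = [
    "full_attention" if layer_type is None else "sliding_attention"
    for layer_type in layer_types
]

# All non-None window sizes must agree, so a single int remains.
config["sliding_window"] = next(filter(None, sliding_window), None)

assert config["layer_types"] == [
    "sliding_attention", "full_attention",
    "sliding_attention", "full_attention",
]
assert config["sliding_window"] == 4096
```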