Fix mistral sliding window parsing (#33521)

Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
Andy Lo
2026-02-02 05:08:04 +00:00
committed by GitHub
parent ce88756b96
commit beb8899482
2 changed files with 25 additions and 22 deletions

View File

@@ -225,19 +225,6 @@ class MistralConfigParser(ConfigParserBase):
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if (sliding_window := getattr(config, "sliding_window", None)) and isinstance(
sliding_window, list
):
pattern_repeats = config.num_hidden_layers // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config.layer_types = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
config.sliding_window = next(filter(None, sliding_window), None)
return config_dict, config

View File

@@ -14,6 +14,7 @@ def adapt_config_dict(
defaults: dict[str, Any],
) -> PretrainedConfig:
config_dict = _remap_general_mistral_args(config_dict)
config_dict = _remap_mistral_sliding_window(config_dict)
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
@@ -161,6 +162,29 @@ def _remap_general_mistral_args(config: dict) -> dict:
return config
def _remap_mistral_sliding_window(config: dict) -> dict:
# Remap sliding_window (list) -> layer_types (list) + sliding window (int)
# for HF compatibility
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if sliding_window := config.get("sliding_window"):
if isinstance(sliding_window, list):
pattern_repeats = config["num_hidden_layers"] // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config["layer_types"] = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
assert len(set(sliding_window) - {None}) <= 1, sliding_window
config["sliding_window"] = next(filter(None, sliding_window), None)
elif isinstance(sliding_window, int) and config.get("layer_types") is None:
config["layer_types"] = ["sliding_attention"] * config["num_hidden_layers"]
else:
raise ValueError(f"Unsupported sliding_window type: {sliding_window}")
return config
def _remap_mistral_quantization_args(config: dict) -> dict:
if config.get("quantization"):
quantization = config.pop("quantization", {})
@@ -195,14 +219,6 @@ def _remap_mistral_audio_args(config: dict) -> dict:
else:
block_pool_size = 1
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
if _maybe_sliding_window is None:
sliding_window = None
elif _maybe_sliding_window.isdigit():
sliding_window = int(_maybe_sliding_window)
else:
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
architecture = (
"VoxtralRealtimeGeneration"
if encoder_args.get("causal")
@@ -229,7 +245,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
is_causal=encoder_args.get("causal", False),
sliding_window=sliding_window,
sliding_window=encoder_args.get("sliding_window", None),
block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
# only needed for RoPE