Fix mistral sliding window parsing (#33521)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
@@ -225,19 +225,6 @@ class MistralConfigParser(ConfigParserBase):
|
||||
|
||||
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
|
||||
|
||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||
# to int and add the layer_types list[str] to make it HF compatible
|
||||
if (sliding_window := getattr(config, "sliding_window", None)) and isinstance(
|
||||
sliding_window, list
|
||||
):
|
||||
pattern_repeats = config.num_hidden_layers // len(sliding_window)
|
||||
layer_types = sliding_window * pattern_repeats
|
||||
config.layer_types = [
|
||||
"full_attention" if layer_type is None else "sliding_attention"
|
||||
for layer_type in layer_types
|
||||
]
|
||||
config.sliding_window = next(filter(None, sliding_window), None)
|
||||
|
||||
return config_dict, config
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ def adapt_config_dict(
|
||||
defaults: dict[str, Any],
|
||||
) -> PretrainedConfig:
|
||||
config_dict = _remap_general_mistral_args(config_dict)
|
||||
config_dict = _remap_mistral_sliding_window(config_dict)
|
||||
|
||||
if bool(config_dict.get("quantization")):
|
||||
config_dict = _remap_mistral_quantization_args(config_dict)
|
||||
@@ -161,6 +162,29 @@ def _remap_general_mistral_args(config: dict) -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def _remap_mistral_sliding_window(config: dict) -> dict:
|
||||
# Remap sliding_window (list) -> layer_types (list) + sliding window (int)
|
||||
# for HF compatibility
|
||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||
# to int and add the layer_types list[str] to make it HF compatible
|
||||
if sliding_window := config.get("sliding_window"):
|
||||
if isinstance(sliding_window, list):
|
||||
pattern_repeats = config["num_hidden_layers"] // len(sliding_window)
|
||||
layer_types = sliding_window * pattern_repeats
|
||||
config["layer_types"] = [
|
||||
"full_attention" if layer_type is None else "sliding_attention"
|
||||
for layer_type in layer_types
|
||||
]
|
||||
assert len(set(sliding_window) - {None}) <= 1, sliding_window
|
||||
config["sliding_window"] = next(filter(None, sliding_window), None)
|
||||
elif isinstance(sliding_window, int) and config.get("layer_types") is None:
|
||||
config["layer_types"] = ["sliding_attention"] * config["num_hidden_layers"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported sliding_window type: {sliding_window}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _remap_mistral_quantization_args(config: dict) -> dict:
|
||||
if config.get("quantization"):
|
||||
quantization = config.pop("quantization", {})
|
||||
@@ -195,14 +219,6 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
||||
else:
|
||||
block_pool_size = 1
|
||||
|
||||
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
|
||||
if _maybe_sliding_window is None:
|
||||
sliding_window = None
|
||||
elif _maybe_sliding_window.isdigit():
|
||||
sliding_window = int(_maybe_sliding_window)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
|
||||
|
||||
architecture = (
|
||||
"VoxtralRealtimeGeneration"
|
||||
if encoder_args.get("causal")
|
||||
@@ -229,7 +245,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
||||
max_source_positions=encoder_args["max_source_positions"],
|
||||
is_encoder_decoder=False, # Override WhisperConfig default
|
||||
is_causal=encoder_args.get("causal", False),
|
||||
sliding_window=sliding_window,
|
||||
sliding_window=encoder_args.get("sliding_window", None),
|
||||
block_pool_size=block_pool_size,
|
||||
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
|
||||
# only needed for RoPE
|
||||
|
||||
Reference in New Issue
Block a user