Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-08-10 04:50:48 +01:00
committed by GitHub
parent 2a84fb422f
commit c49848396d
16 changed files with 123 additions and 231 deletions

View File

@@ -40,8 +40,9 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
maybe_override_with_speculators_target_model, try_get_generation_config,
try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
is_interleaved, maybe_override_with_speculators_target_model,
try_get_generation_config, try_get_safetensors_metadata,
try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
# yapf conflicts with isort for this block
@@ -714,53 +715,31 @@ class ModelConfig:
revision=self.revision,
)
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config.
# TODO: remove this when Gemma 2 config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2
# TODO: remove this when Gemma 3n config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma3n_text":
# 4 sliding window attention followed by 1 full attention
self.hf_text_config.sliding_window_pattern = "LLLLG"
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)
has_interleaved_attention = sliding_window_pattern is not None or (
isinstance(sliding_window, list))
if not self.disable_sliding_window and has_interleaved_attention:
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
logger.warning_once(
"%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
self.hf_text_config.model_type,
backend,
sliding_window_len_min,
)
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
if hasattr(self.hf_text_config, "sliding_window"):
delattr(self.hf_text_config, "sliding_window")
sliding_window = None
# Interleaved attention is not supported by some backends in V0
if (not self.disable_sliding_window
and is_interleaved(self.hf_text_config)
and not envs.VLLM_USE_V1
and (backend := envs.VLLM_ATTENTION_BACKEND)
in ("XFORMERS", "FLASHINFER")):
logger.warning_once(
"%s has interleaved attention, which is currently not "
"supported by the %s backend. Disabling sliding window and "
"capping the max length to the sliding window size (%d).",
self.hf_text_config.model_type,
backend,
self.hf_text_config.sliding_window,
)
self.disable_sliding_window = True
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
if self.disable_sliding_window:
# Set after get_and_verify_max_len to ensure that max_model_len
# can be correctly capped to sliding window size
self.hf_text_config.sliding_window = None
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@@ -1322,27 +1301,10 @@ class ModelConfig:
if self.use_async_output_proc:
self.use_async_output_proc = False
def get_hf_config_sliding_window(
self) -> Union[Optional[int], list[Optional[int]]]:
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if self.disable_sliding_window:
return None
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int:
return getattr(self.hf_text_config, "vocab_size", 0)
@@ -1762,7 +1724,7 @@ class ModelConfig:
tokenizer_config=tokenizer_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
sliding_window=self.get_sliding_window(),
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
logger.info("Using max model len %s", max_model_len)
@@ -3305,7 +3267,7 @@ def _get_and_verify_max_len(
tokenizer_config: Optional[dict],
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
sliding_window: Optional[int],
spec_target_max_model_len: Optional[int] = None,
encoder_config: Optional[Any] = None,
) -> int:
@@ -3344,13 +3306,10 @@ def _get_and_verify_max_len(
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:
sliding_window_len_min = get_min_sliding_window(sliding_window_len)
max_len_key = "sliding_window" \
if sliding_window_len_min < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len,
sliding_window_len_min)
if (disable_sliding_window and sliding_window is not None
and sliding_window < derived_max_model_len):
max_len_key = "sliding_window"
derived_max_model_len = sliding_window
# Consider model_max_length in tokenizer_config
if tokenizer_config:
@@ -3451,14 +3410,6 @@ def _get_and_verify_max_len(
return int(max_model_len)
def get_min_sliding_window(
sliding_window: Union[int, list[Optional[int]]]) -> int:
if isinstance(sliding_window, list):
return min(s for s in sliding_window if s is not None)
return sliding_window
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, list[str]]]):
"""