Enable hybrid attention models for Transformers backend (#18494)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-23 04:12:08 +02:00
committed by GitHub
parent c6b636f9fb
commit 4b0da7b60e
4 changed files with 106 additions and 30 deletions

View File

@@ -533,13 +533,17 @@ class ModelConfig:
self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in interleaved_attn_models))
# Workaround for Gemma 2, which uses interleaved sliding window
# attention but does not specify the sliding window pattern in its
# config. TODO: remove this once Gemma 2 is fixed in Transformers.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2
if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)
if not (self.disable_sliding_window or sliding_window_pattern is None):
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
@@ -1037,8 +1041,7 @@ class ModelConfig:
if self.use_async_output_proc:
self.use_async_output_proc = False
def get_hf_config_sliding_window(
self) -> Union[Optional[int], list[Optional[int]]]:
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1049,7 +1052,7 @@ class ModelConfig:
return None
return getattr(self.hf_text_config, "sliding_window", None)
def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.