Enable hybrid attention models for Transformers backend (#18494)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-05-23 04:12:08 +02:00
Committed by: GitHub
Parent: c6b636f9fb
Commit: 4b0da7b60e
4 changed files with 106 additions and 30 deletions


@@ -16,6 +16,7 @@
 """Wrapper around `transformers` models"""
 import re
 from collections.abc import Iterable
+from contextlib import nullcontext
 from typing import Literal, Optional, Union
 
 import torch
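
Note: the new `nullcontext` import is the standard-library no-op context manager; it lets the constructor run the same `with` statement whether or not a config override is needed. A minimal, standalone sketch of that pattern (the `override` and `build` names are illustrative, not vLLM code):

from contextlib import contextmanager, nullcontext

@contextmanager
def override():
    # Hypothetical stand-in for a real override context manager.
    print("override active")
    yield
    print("override restored")

def build(needs_override: bool):
    # Pick a real context manager or a do-nothing placeholder up front,
    # then run the same construction code under either one.
    ctx = override() if needs_override else nullcontext()
    with ctx:
        print("constructing model")

build(False)  # prints only "constructing model"
build(True)   # construction runs between "override active" and "override restored"
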
@@ -110,6 +111,33 @@ def replace_linear_class(
     )
 
 
+class ConfigOverride:
+    """Context manager to temporarily override config attributes."""
+
+    def __init__(self, config: PretrainedConfig, **kwargs):
+        self.config = config
+        self.kwargs = kwargs
+        self.kwargs_original = {}
+        self.kwargs_delete = set()
+
+    def __enter__(self):
+        """Override config attributes."""
+        for key, value in self.kwargs.items():
+            if not hasattr(self.config, key):
+                self.kwargs_delete.add(key)
+            self.kwargs_original[key] = getattr(self.config, key, None)
+            setattr(self.config, key, value)
+        return self.config
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Restore original config attributes."""
+        for key, value in self.kwargs_original.items():
+            if key in self.kwargs_delete:
+                delattr(self.config, key)
+            else:
+                setattr(self.config, key, value)
+
+
 class TransformersModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
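
Note: a rough usage sketch of the `ConfigOverride` class added above. `SimpleNamespace` stands in for a `PretrainedConfig`, since the class only relies on hasattr/getattr/setattr/delattr; the attribute values are made up for illustration:

from types import SimpleNamespace

# Stand-in object for a PretrainedConfig (example values only).
config = SimpleNamespace(interleaved_sliding_window=4096)

# Temporarily expose `sliding_window` so code that expects it keeps working.
with ConfigOverride(config, sliding_window=config.interleaved_sliding_window):
    assert config.sliding_window == 4096

# The attribute did not exist before entering, so it is deleted again on exit.
assert not hasattr(config, "sliding_window")
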
@@ -135,8 +163,17 @@ class TransformersModel(nn.Module):
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # vLLM handles interleaved sliding window attention by creating a new
+        # interleaved_sliding_window attribute and deleting the sliding_window
+        # attribute. This breaks the constructors in Transformers so we
+        # temporarily add the attribute back to construct the model.
+        config_override = nullcontext()
+        if hasattr(config, "interleaved_sliding_window"):
+            config_override = ConfigOverride(
+                config, sliding_window=config.interleaved_sliding_window)
+
         # Use meta device to delay allocating GPU tensors
-        with torch.device("meta"):
+        with torch.device("meta"), config_override:
             # FIXME(Isotr0py): We need to refactor this part in the future to
             # avoid registering an extra model layer, otherwise we will need a
             # weights mapper to rename weights.
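
Note: `with torch.device("meta")` (usable as a context manager since PyTorch 2.0) makes tensor constructors produce meta tensors, which carry shape and dtype but no storage, so the model skeleton is built without allocating GPU memory; real weights are materialized later when they are loaded. A tiny standalone illustration:

import torch

with torch.device("meta"):
    weight = torch.empty(8192, 8192)  # no memory is allocated here

print(weight.device)  # meta
print(weight.shape)   # torch.Size([8192, 8192])
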
@@ -262,9 +299,17 @@ class TransformersModel(nn.Module):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(self.config.num_hidden_layers,
                                     self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        for i in range(start, end):
+            # Handle interleaved sliding window attention
+            sliding_window = None
+            if (hasattr(self.config, "interleaved_sliding_window")
+                    and hasattr(self.config, "sliding_window_pattern")
+                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
+                sliding_window = self.config.interleaved_sliding_window
+
+            attention_instances[i] = Attention(
                 num_heads=num_heads,
                 head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ class TransformersModel(nn.Module):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                 prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+        return attention_instances
 
     def init_buffers(self, module: nn.Module):
         """