Enable hybrid attention models for Transformers backend (#18494)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -16,6 +16,7 @@
 """Wrapper around `transformers` models"""
 import re
 from collections.abc import Iterable
+from contextlib import nullcontext
 from typing import Literal, Optional, Union
 
 import torch
@@ -110,6 +111,33 @@ def replace_linear_class(
     )
 
 
+class ConfigOverride:
+    """Context manager to temporarily override config attributes."""
+
+    def __init__(self, config: PretrainedConfig, **kwargs):
+        self.config = config
+        self.kwargs = kwargs
+        self.kwargs_original = {}
+        self.kwargs_delete = set()
+
+    def __enter__(self):
+        """Override config attributes."""
+        for key, value in self.kwargs.items():
+            if not hasattr(self.config, key):
+                self.kwargs_delete.add(key)
+            self.kwargs_original[key] = getattr(self.config, key, None)
+            setattr(self.config, key, value)
+        return self.config
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Restore original config attributes."""
+        for key, value in self.kwargs_original.items():
+            if key in self.kwargs_delete:
+                delattr(self.config, key)
+            else:
+                setattr(self.config, key, value)
+
+
 class TransformersModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
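A minimal sketch of how `ConfigOverride` behaves (the bare `PretrainedConfig()` below is only for illustration, not part of the change): an attribute that already existed on the config is restored to its original value on exit, while an attribute that did not exist beforehand is deleted again, so the override leaves no trace after construction.

    from transformers import PretrainedConfig

    config = PretrainedConfig()  # has no sliding_window attribute yet
    with ConfigOverride(config, sliding_window=4096) as cfg:
        assert cfg.sliding_window == 4096            # visible inside the block
    assert not hasattr(config, "sliding_window")     # deleted again on exit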
@@ -135,8 +163,17 @@ class TransformersModel(nn.Module):
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # vLLM handles interleaved sliding window attention by creating a new
+        # interleaved_sliding_window attribute and deleting the sliding_window
+        # attribute. This breaks the constructors in Transformers so we
+        # temporarily add the attribute back to construct the model.
+        config_override = nullcontext()
+        if hasattr(config, "interleaved_sliding_window"):
+            config_override = ConfigOverride(
+                config, sliding_window=config.interleaved_sliding_window)
+
         # Use meta device to delay allocating GPU tensors
-        with torch.device("meta"):
+        with torch.device("meta"), config_override:
             # FIXME(Isotr0py): We need to refactor this part in the future to
             # avoid registering an extra model layer, otherwise we will need a
             # weights mapper to rename weights.
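To make the failure mode concrete, a rough sketch of the situation the comment above describes (the tiny `Gemma2Config` dimensions are made up, the config surgery simulates what the comment says vLLM does, and `AutoModel.from_config` merely stands in for the construction that happens inside the `with` block):

    from transformers import AutoModel, Gemma2Config

    config = Gemma2Config(vocab_size=64, hidden_size=32, intermediate_size=64,
                          num_hidden_layers=2, num_attention_heads=2,
                          num_key_value_heads=1, head_dim=16)
    # Simulate vLLM's handling of interleaved sliding window attention:
    config.interleaved_sliding_window = config.sliding_window
    del config.sliding_window
    # Gemma-2's layer constructors read config.sliding_window, which would now
    # raise AttributeError, so the attribute is put back only for construction:
    with ConfigOverride(config,
                        sliding_window=config.interleaved_sliding_window):
        model = AutoModel.from_config(config)
    assert not hasattr(config, "sliding_window")  # config is clean again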
@@ -262,9 +299,17 @@ class TransformersModel(nn.Module):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(self.config.num_hidden_layers,
                                     self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        for i in range(start, end):
+            # Handle interleaved sliding window attention
+            sliding_window = None
+            if (hasattr(self.config, "interleaved_sliding_window")
+                    and hasattr(self.config, "sliding_window_pattern")
+                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
+                sliding_window = self.config.interleaved_sliding_window
+
+            attention_instances[i] = Attention(
                 num_heads=num_heads,
                 head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ class TransformersModel(nn.Module):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                 prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+        return attention_instances
 
     def init_buffers(self, module: nn.Module):
         """
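For reference, the per-layer selection above makes every `sliding_window_pattern`-th layer (those where `i + 1` is divisible by the pattern) a full-attention layer, while all other layers use the interleaved window. A standalone sketch with made-up numbers:

    # Hypothetical model: 12 layers, sliding_window_pattern = 6,
    # interleaved_sliding_window = 4096 (values chosen purely for illustration).
    pattern, window, num_layers = 6, 4096, 12
    per_layer_sliding_window = {
        i: window if (i + 1) % pattern > 0 else None
        for i in range(num_layers)
    }
    # Layers 5 and 11 get None (full attention); every other layer gets the
    # 4096-token window, mirroring the sliding_window default of None in the
    # loop above.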