Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)

This commit is contained in:
Tao He
2025-05-13 10:52:47 +08:00
committed by GitHub
parent f6518b2b48
commit 60f7624334
17 changed files with 2444 additions and 32 deletions

View File

@@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from typing import Any, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -99,17 +99,20 @@ class Qwen2MLP(nn.Module):
class Qwen2Attention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str,
Any]] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -131,6 +134,7 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -155,15 +159,21 @@ class Qwen2Attention(nn.Module):
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward(
self,
@@ -192,6 +202,9 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
@@ -213,6 +226,7 @@ class Qwen2DecoderLayer(nn.Module):
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,