Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)
This commit is contained in:
@@ -175,6 +175,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -198,6 +199,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -221,14 +223,20 @@ class Qwen2MoeAttention(nn.Module):
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
**{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
} if dual_chunk_attention_config else {})
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -256,6 +264,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = Qwen2MoeAttention(
|
||||
@@ -268,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
|
||||
Reference in New Issue
Block a user