[Misc] Support attention logits soft-capping with flash-attn (#7022)

Woosuk Kwon
2024-08-01 13:14:37 -07:00
committed by GitHub
parent 562e580abc
commit 805a8a75f2
14 changed files with 71 additions and 47 deletions

@@ -90,7 +90,8 @@ class Gemma2Attention(nn.Module):
                  max_position_embeddings: int,
                  rope_theta: float,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 attn_logits_soft_cap: Optional[float] = None) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
@@ -150,7 +151,8 @@ class Gemma2Attention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
-                              quant_config=quant_config)
+                              quant_config=quant_config,
+                              logits_soft_cap=attn_logits_soft_cap)

     def forward(
         self,
@@ -189,6 +191,7 @@ class Gemma2DecoderLayer(nn.Module):
             rope_theta=config.rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
+            attn_logits_soft_cap=config.attn_logit_softcapping,
         )
         self.hidden_size = config.hidden_size
         self.mlp = Gemma2MLP(
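
For reference, the attn_logit_softcapping value that Gemma2DecoderLayer forwards here squashes the pre-softmax attention scores into (-cap, cap) with a scaled tanh; that is what the new logits_soft_cap argument asks the flash-attn backend to apply. A minimal sketch of the transform (the soft_cap_logits helper name is illustrative, not part of this diff):

    import torch

    def soft_cap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # Divide into tanh's near-linear region, squash, then rescale:
        # outputs stay strictly inside (-cap, cap), while small scores
        # pass through almost unchanged.
        return cap * torch.tanh(logits / cap)

With cap = 50.0 (the attn_logit_softcapping value in Gemma 2's released config), a raw score of 200.0 maps to about 49.97, so outlier logits are bounded before the softmax.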