[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-22 14:04:42 -08:00
committed by GitHub
parent db100c5cde
commit eebad39f26
77 changed files with 876 additions and 648 deletions

View File

@@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
@@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config,
quant_config)
self.self_attention = FalconAttention(
config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.mlp = FalconMLP(config, quant_config)
self.config = config
@@ -357,8 +365,8 @@ class FalconModel(nn.Module):
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
lambda prefix: FalconDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
# Final Layer Norm