[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
|
||||
def __init__(self,
|
||||
config: Starcoder2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: Starcoder2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Starcoder2Attention(config,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.norm_epsilon)
|
||||
@@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Starcoder2DecoderLayer(
|
||||
config, cache_config, quant_config=quant_config),
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
||||
|
||||
Reference in New Issue
Block a user