Use runtime profiling to replace manual memory analyzers (#81)

This commit is contained in:
Zhuohan Li
2023-05-19 11:35:44 -06:00
committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@@ -58,7 +58,8 @@ class GPT2Attention(nn.Module):
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = GPTCacheFlowAttention(scale=self.scale)
self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
scale=self.scale)
def forward(
self,