Use runtime profiling to replace manual memory analyzers (#81)
This commit is contained in:
@@ -58,7 +58,8 @@ class GPT2Attention(nn.Module):
|
||||
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.attn = GPTCacheFlowAttention(scale=self.scale)
|
||||
self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
|
||||
scale=self.scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user