Implement custom kernel for LLaMA rotary embedding (#14)

Woosuk Kwon
2023-03-30 11:04:21 -07:00
committed by GitHub
parent 80a2f812f1
commit 88c0268a18
10 changed files with 318 additions and 69 deletions


@@ -165,8 +165,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.head_size = config.hidden_size // self.num_heads
         self.ffn_size = config.intermediate_size
         self.vocab_size = config.vocab_size
-        # FIXME
-        self.max_position = 2048
+        self.max_position = 8192

     def _get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
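
For context, a custom rotary embedding kernel typically replaces a per-token cos/sin rotation of the query and key vectors that would otherwise be done in framework code. The sketch below is a minimal PyTorch reference of that computation (interleaved-pair convention, as in the original LLaMA implementation); the function names, shapes, and default base are illustrative assumptions, not the commit's actual CUDA kernel.

# Minimal reference sketch of rotary position embeddings (not the commit's CUDA kernel).
# Names and shapes are assumptions for illustration only.
import torch

def build_rotary_cache(head_size: int, max_position: int, base: float = 10000.0):
    # Precompute per-position rotation angles for each (even, odd) dimension pair.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
    positions = torch.arange(max_position).float()
    freqs = torch.outer(positions, inv_freq)  # [max_position, head_size // 2]
    return freqs.cos(), freqs.sin()

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position: int):
    # x: [num_heads, head_size]; rotate each (even, odd) pair by the angle for `position`.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    c, s = cos[position], sin[position]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * c - x2 * s
    out[..., 1::2] = x1 * s + x2 * c
    return out

In this reading, raising max_position from the hard-coded 2048 (previously marked FIXME) to 8192 simply enlarges the precomputed cos/sin cache that the memory analyzer has to account for.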