Add support for GPT-2 (#60)

This commit is contained in:
Woosuk Kwon
2023-05-04 02:59:56 -07:00
committed by GitHub
parent 130d5fd8c7
commit e548c1488a
7 changed files with 350 additions and 8 deletions

View File

@@ -5,9 +5,11 @@ import torch.nn as nn
from transformers import AutoConfig
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.gpt2 import GPT2LMHeadModel
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
@@ -15,6 +17,7 @@ from cacheflow.models.utils import get_torch_dtype
_MODELS = {
'gpt2': GPT2LMHeadModel,
'llama': LlamaForCausalLM,
'opt': OPTForCausalLM,
'stablelm': GPTNeoXForCausalLM,
@@ -22,6 +25,7 @@ _MODELS = {
}
_MEMORY_ANALYZERS = {
'gpt2': GPT2MemoryAnalyzer,
'llama': LlamaMemoryAnalyzer,
'opt': OPTMemoryAnalyzer,
'stablelm': GPTNeoXMemoryAnalyzer,