Add support for GPT-2 (#60)
This commit is contained in:
@@ -5,9 +5,11 @@ import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.gpt2 import GPT2LMHeadModel
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
@@ -15,6 +17,7 @@ from cacheflow.models.utils import get_torch_dtype


_MODELS = {
    'gpt2': GPT2LMHeadModel,
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
    'stablelm': GPTNeoXForCausalLM,
@@ -22,6 +25,7 @@ _MODELS = {
}

_MEMORY_ANALYZERS = {
    'gpt2': GPT2MemoryAnalyzer,
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
    'stablelm': GPTNeoXMemoryAnalyzer,
Reference in New Issue
Block a user