Use runtime profiling to replace manual memory analyzers (#81)
This commit is contained in:
@@ -5,9 +5,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from cacheflow.model_executor.memory_analyzer import (
|
||||
CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
|
||||
LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
|
||||
from cacheflow.model_executor.models import (
|
||||
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
|
||||
from cacheflow.model_executor.utils import get_torch_dtype
|
||||
@@ -22,14 +19,6 @@ _MODEL_REGISTRY = {
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
}
|
||||
|
||||
_MEMORY_ANALYZERS = {
|
||||
"GPT2LMHeadModel": GPT2MemoryAnalyzer,
|
||||
"GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
|
||||
"LlamaForCausalLM": LlamaMemoryAnalyzer,
|
||||
"OPTForCausalLM": OPTMemoryAnalyzer,
|
||||
}
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
@@ -41,17 +30,6 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
||||
)
|
||||
|
||||
|
||||
def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MEMORY_ANALYZERS:
|
||||
return _MEMORY_ANALYZERS[arch]
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
@@ -100,18 +78,3 @@ def get_model(
|
||||
model = model.cuda()
|
||||
return model.eval(), torch_dtype
|
||||
|
||||
|
||||
def get_memory_analyzer(
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
dtype: str,
|
||||
gpu_memory: int,
|
||||
cpu_memory: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
) -> CacheFlowMemoryAnalyzer:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
torch_dtype = _get_dtype(config, dtype)
|
||||
memory_analyzer = _get_memory_analyzer(config)
|
||||
return memory_analyzer(
|
||||
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
|
||||
tensor_parallel_size)
|
||||
|
||||
Reference in New Issue
Block a user