Refactor system architecture (#109)

This commit is contained in:
Woosuk Kwon
2023-05-20 13:06:59 -07:00
committed by GitHub
parent 7297fa6f7c
commit c3442c1f6f
24 changed files with 1017 additions and 1034 deletions

View File

@@ -1,16 +1,13 @@
"""Utilities for selecting and loading models."""
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig
from transformers import PretrainedConfig
from cacheflow.config import ModelConfig
from cacheflow.model_executor.models import (
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
from cacheflow.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"GPT2LMHeadModel": GPT2LMHeadModel,
@@ -19,6 +16,7 @@ _MODEL_REGISTRY = {
"OPTForCausalLM": OPTForCausalLM,
}
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
architectures = getattr(config, "architectures", [])
for arch in architectures:
@@ -30,51 +28,22 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
)
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if dtype == "default":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
torch_dtype = get_torch_dtype(dtype)
if torch_dtype != config_dtype and config_dtype != torch.float32:
# TODO(woosuk): Allow using float16 for bfloat16 models and
# vice versa. Print a warning message and continue.
raise ValueError(
f"Cannot use {torch_dtype} for {config_dtype} model.")
return torch_dtype
def get_model(
model_name: str,
dtype: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
) -> nn.Module:
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
torch.set_default_dtype(torch_dtype)
model_class = _get_model_architecture(config)
def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
torch.set_default_dtype(model_config.dtype)
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(config)
if use_dummy_weights:
model = model_class(model_config.hf_config)
if model_config.use_dummy_weights:
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_name, cache_dir, use_np_cache)
model.load_weights(
model_config.model, model_config.download_dir,
model_config.use_np_weights)
model = model.cuda()
return model.eval(), torch_dtype
return model.eval()