Support FP32 (#141)
This commit is contained in:
@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
|
||||
config_dtype = torch.float32
|
||||
|
||||
dtype = dtype.lower()
|
||||
if dtype == "default":
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
# Following the common practice, we use float16 for float32 models.
|
||||
torch_dtype = torch.float16
|
||||
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is not allowed.
|
||||
raise ValueError(
|
||||
f"Cannot use {torch_dtype} for {config_dtype} model.")
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
|
||||
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16:
|
||||
|
||||
@@ -28,9 +28,10 @@ class LLM:
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
we support `float16` and `bfloat16`. If `default`, we use the
|
||||
`torch_dtype` attribute of the model config. If the `torch_dtype`
|
||||
is `float32`, we use `float16` instead.
|
||||
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
||||
the `torch_dtype` attribute specified in the model config file.
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
"""
|
||||
|
||||
@@ -38,7 +39,7 @@ class LLM:
|
||||
self,
|
||||
model: str,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "default",
|
||||
dtype: str = "auto",
|
||||
seed: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
||||
@@ -10,7 +10,7 @@ from cacheflow import cache_ops
|
||||
from cacheflow import pos_encoding_ops
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
|
||||
|
||||
|
||||
class GPTCacheFlowAttention(nn.Module):
|
||||
@@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
||||
|
||||
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
||||
raise ValueError(f'head_size ({self.head_size}) is not supported by '
|
||||
'the single_query_cached_kv_attention kernel. '
|
||||
'Use one of the following head sizes: '
|
||||
f'{_SUPPORTED_HEAD_SIZES}.')
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@ class ServerArgs:
|
||||
download_dir: Optional[str] = None
|
||||
use_np_weights: bool = False
|
||||
use_dummy_weights: bool = False
|
||||
dtype: str = "default"
|
||||
dtype: str = "auto"
|
||||
seed: int = 0
|
||||
worker_use_ray: bool = False
|
||||
pipeline_parallel_size: int = 1
|
||||
@@ -49,9 +49,9 @@ class ServerArgs:
|
||||
help='use dummy values for model weights')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
|
||||
choices=['default', 'half', 'bfloat16'],
|
||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "default" option will use FP16 precision '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
# Parallel arguments
|
||||
@@ -67,7 +67,7 @@ class ServerArgs:
|
||||
# KV cache arguments
|
||||
parser.add_argument('--block-size', type=int,
|
||||
default=ServerArgs.block_size,
|
||||
choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
choices=[8, 16, 32],
|
||||
help='token block size')
|
||||
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||
parser.add_argument('--seed', type=int, default=ServerArgs.seed,
|
||||
|
||||
Reference in New Issue
Block a user