Support FP32 (#141)

This commit is contained in:
Woosuk Kwon
2023-06-07 00:40:21 -07:00
committed by GitHub
parent 376725ce74
commit e38074b1e6
8 changed files with 65 additions and 54 deletions

View File

@@ -28,9 +28,10 @@ class LLM:
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float16` and `bfloat16`. If `default`, we use the
`torch_dtype` attribute of the model config. If the `torch_dtype`
is `float32`, we use `float16` instead.
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
seed: The seed to initialize the random number generator for sampling.
"""
@@ -38,7 +39,7 @@ class LLM:
self,
model: str,
tensor_parallel_size: int = 1,
dtype: str = "default",
dtype: str = "auto",
seed: int = 0,
**kwargs,
) -> None: