Support FP32 (#141)
@@ -28,9 +28,10 @@ class LLM:
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
-            we support `float16` and `bfloat16`. If `default`, we use the
-            `torch_dtype` attribute of the model config. If the `torch_dtype`
-            is `float32`, we use `float16` instead.
+            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+            the `torch_dtype` attribute specified in the model config file.
+            However, if the `torch_dtype` in the config is `float32`, we will
+            use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
     """
 
@@ -38,7 +39,7 @@ class LLM:
         self,
         model: str,
         tensor_parallel_size: int = 1,
-        dtype: str = "default",
+        dtype: str = "auto",
         seed: int = 0,
         **kwargs,
     ) -> None:
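
The `auto` path amounts to a small dtype-resolution step. Below is a minimal sketch of that rule, assuming a hypothetical helper `get_torch_dtype` built on Hugging Face `AutoConfig`; it is not vLLM's actual implementation, only an illustration of the behavior the updated docstring describes.

```python
# Minimal sketch, not vLLM's actual code: resolves the `dtype` argument the
# way the docstring above describes. `get_torch_dtype` is a hypothetical name.
import torch
from transformers import AutoConfig

_STR_TO_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

def get_torch_dtype(model: str, dtype: str = "auto") -> torch.dtype:
    # Explicit dtypes are honored as-is; float32 is now a supported choice.
    if dtype != "auto":
        return _STR_TO_DTYPE[dtype]
    # "auto": read `torch_dtype` from the model's config file.
    config_dtype = AutoConfig.from_pretrained(model).torch_dtype
    if isinstance(config_dtype, str):  # some configs store it as a string
        config_dtype = _STR_TO_DTYPE.get(config_dtype)
    # Per the docstring, a float32 (or missing) config dtype falls back to
    # float16 under "auto".
    if config_dtype is None or config_dtype == torch.float32:
        return torch.float16
    return config_dtype
```

Under this rule, passing `dtype="float32"` explicitly keeps the model in full precision (which is what this commit enables), while `dtype="auto"` never silently loads float32 weights.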