Support bfloat16 data type (#54)
This commit is contained in:
@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+    # NOTE(woosuk): FlashAttention does not support float32.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
Reference in New Issue
Block a user