Support FP32 (#141)

Woosuk Kwon
2023-06-07 00:40:21 -07:00
committed by GitHub
parent 376725ce74
commit e38074b1e6
8 changed files with 65 additions and 54 deletions


@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
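
The switch above turns the runtime head_size into a compile-time template argument, so every case that stays in the file still instantiates a full set of kernels per dtype and block size; the commented-out sizes now fall through to the TORCH_CHECK in the default branch and fail at runtime instead of adding to compile time. A minimal sketch of this runtime-to-compile-time dispatch pattern, with illustrative names (launch_attention_stub and dispatch_by_head_size are not from the diff) and only the head sizes visible in these hunks:

#include <stdexcept>
#include <string>

// Illustrative stand-in for LAUNCH_ATTENTION_KERNEL: the head size must be a
// template parameter because the kernel sizes registers and shared memory on it.
template <typename T, int HEAD_SIZE>
void launch_attention_stub(/* kernel arguments elided */) {
  // ... configure grid/block and launch the kernel specialized for HEAD_SIZE ...
}

// Runtime head_size -> compile-time HEAD_SIZE. Every case is a separate template
// instantiation, which is why dropping cases shortens compilation.
template <typename T>
void dispatch_by_head_size(int head_size) {
  switch (head_size) {
    case 64:  launch_attention_stub<T, 64>();  break;
    case 128: launch_attention_stub<T, 128>(); break;
    default:
      // Omitted sizes (e.g. 32, 160, 192, 256) end up here at runtime.
      throw std::runtime_error("Unsupported head size: " + std::to_string(head_size));
  }
}
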
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
     context_lens, \
     max_context_len);
 
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) { \
-    case 1: \
-      CALL_KERNEL_LAUNCHER(T, 1); \
-      break; \
-    case 2: \
-      CALL_KERNEL_LAUNCHER(T, 2); \
-      break; \
-    case 4: \
-      CALL_KERNEL_LAUNCHER(T, 4); \
-      break; \
+    /* case 1: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 1); */ \
+    /*   break; */ \
+    /* case 2: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 2); */ \
+    /*   break; */ \
+    /* case 4: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 4); */ \
+    /*   break; */ \
     case 8: \
       CALL_KERNEL_LAUNCHER(T, 8); \
       break; \
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32: \
       CALL_KERNEL_LAUNCHER(T, 32); \
       break; \
-    case 64: \
-      CALL_KERNEL_LAUNCHER(T, 64); \
-      break; \
-    case 128: \
-      CALL_KERNEL_LAUNCHER(T, 128); \
-      break; \
-    case 256: \
-      CALL_KERNEL_LAUNCHER(T, 256); \
-      break; \
+    /* case 64: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 64); */ \
+    /*   break; */ \
+    /* case 128: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 128); */ \
+    /*   break; */ \
+    /* case 256: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 256); */ \
+    /*   break; */ \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
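
CALL_KERNEL_LAUNCHER_BLOCK_SIZE applies the same trick to block_size, wrapped in a macro so the whole switch can be stamped out once per element type. After preprocessing, a call such as CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float) reduces to roughly the following sketch (the launcher's argument list is elided, and only the block sizes visible in these hunks are shown; any cases that fall between the two hunks are unchanged):

// Rough expansion of CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); the real macro
// forwards the output, query, and KV-cache tensors plus the remaining launch
// arguments (ending in context_lens and max_context_len) to the launcher.
switch (block_size) {
  case 8:
    single_query_cached_kv_attention_launcher<float, 8>(/* ...args... */);
    break;
  case 32:
    single_query_cached_kv_attention_launcher<float, 32>(/* ...args... */);
    break;
  default:
    TORCH_CHECK(false, "Unsupported block size: ", block_size);
    break;
}
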
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens, // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
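
With this change the dtype dispatch covers all three element types, and FP32 queries select T = float instead of hitting the old TODO path. After the patch the block reads roughly as follows (the trailing else branch sits below the hunk and is assumed here by analogy with the head-size and block-size checks):

if (query.dtype() == at::ScalarType::Float) {
  // New FP32 path: kernels are instantiated directly on float.
  CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {
  // FP16 tensors are passed as raw 16-bit words; the kernels handle the
  // half-precision loads and conversions themselves.
  CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
} else if (query.dtype() == at::ScalarType::BFloat16) {
  CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
} else {
  // Assumed fallback for any other dtype.
  TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}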