Support FP32 (#141)
@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
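The pruning above trades kernel coverage for build speed: each case arm expands LAUNCH_ATTENTION_KERNEL into a distinct template instantiation, and nvcc compiles every (dtype, head size, block size) combination as a separate kernel, so adding float as a third dtype multiplies the build unless some sizes are dropped. A minimal sketch of the pattern, with illustrative names (toy_attention_kernel and launch_for_head_size are not part of the source):

#include <cstdio>

// Each (T, HEAD_SIZE) pair below is a separate compiled kernel; removing
// a case arm removes one instantiation per dtype from the build.
template <typename T, int HEAD_SIZE>
__global__ void toy_attention_kernel(const T* in, T* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < HEAD_SIZE) out[i] = in[i];
}

template <typename T>
void launch_for_head_size(int head_size, const T* in, T* out) {
  switch (head_size) {
    case 64:  toy_attention_kernel<T, 64><<<1, 64>>>(in, out);   break;
    case 128: toy_attention_kernel<T, 128><<<1, 128>>>(in, out); break;
    default:  std::printf("unsupported head size: %d\n", head_size);
  }
}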
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
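A head size whose case is now commented out falls through to the default arm, where TORCH_CHECK(false, ...) throws rather than launching anything. A small self-contained sketch of that guard's behavior (assuming a libtorch build; the head-size value is arbitrary):

#include <torch/torch.h>
#include <iostream>

int main() {
  try {
    int head_size = 160;  // disabled above, so it would hit the default arm
    TORCH_CHECK(false, "Unsupported head size: ", head_size);
  } catch (const c10::Error& e) {
    // TORCH_CHECK raises c10::Error with the concatenated message, so the
    // caller sees an ordinary exception instead of a bad kernel launch.
    std::cout << e.what() << std::endl;
  }
  return 0;
}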
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
     context_lens, \
     max_context_len);
 
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) { \
-    case 1: \
-      CALL_KERNEL_LAUNCHER(T, 1); \
-      break; \
-    case 2: \
-      CALL_KERNEL_LAUNCHER(T, 2); \
-      break; \
-    case 4: \
-      CALL_KERNEL_LAUNCHER(T, 4); \
-      break; \
+    /* case 1: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 1); */ \
+    /*   break; */ \
+    /* case 2: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 2); */ \
+    /*   break; */ \
+    /* case 4: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 4); */ \
+    /*   break; */ \
     case 8: \
       CALL_KERNEL_LAUNCHER(T, 8); \
       break; \
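The disabled branches inside the macro use /* ... */ rather than // for a preprocessing reason: backslash-newline splicing happens before comment removal, so a // comment on a continued macro line would absorb the continuation and comment out the remainder of the macro body. A tiny illustration with made-up names:

#include <cstdio>

void handle_default(int x) { std::printf("default: %d\n", x); }

// OK: block comments end where they close, so the macro body survives.
#define DISPATCH_OK(x)   \
  /* case 1: disabled */ \
  handle_default(x);

// Writing the disabled line as "// case 1: disabled \" instead would
// splice the following lines into the comment and empty the macro body.

int main() {
  DISPATCH_OK(42);  // prints "default: 42"
  return 0;
}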
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32: \
       CALL_KERNEL_LAUNCHER(T, 32); \
       break; \
-    case 64: \
-      CALL_KERNEL_LAUNCHER(T, 64); \
-      break; \
-    case 128: \
-      CALL_KERNEL_LAUNCHER(T, 128); \
-      break; \
-    case 256: \
-      CALL_KERNEL_LAUNCHER(T, 256); \
-      break; \
+    /* case 64: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 64); */ \
+    /*   break; */ \
+    /* case 128: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 128); */ \
+    /*   break; */ \
+    /* case 256: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 256); */ \
+    /*   break; */ \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
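Taken together, the two macros form a chain of runtime-to-compile-time dispatch: CALL_KERNEL_LAUNCHER_BLOCK_SIZE switches on block_size, each arm passes the value as a template argument through CALL_KERNEL_LAUNCHER, and the launcher then switches on head_size. A hedged sketch of that pattern with illustrative names (the real macros also forward the tensor arguments):

#include <cstdio>

// Innermost level: by the time we get here, T and BLOCK_SIZE are
// compile-time parameters; only head_size is still dispatched at runtime.
template <typename T, int BLOCK_SIZE>
void launcher(int head_size) {
  std::printf("sizeof(T)=%zu block=%d head=%d\n",
              sizeof(T), BLOCK_SIZE, head_size);
}

// Outer level: mirrors CALL_KERNEL_LAUNCHER_BLOCK_SIZE with the block
// sizes the diff leaves enabled (8 and 32 are visible; 16 is assumed to
// sit in the elided lines between the two hunks).
template <typename T>
void dispatch_block_size(int block_size, int head_size) {
  switch (block_size) {
    case 8:  launcher<T, 8>(head_size);  break;
    case 16: launcher<T, 16>(head_size); break;
    case 32: launcher<T, 32>(head_size); break;
    default: std::printf("unsupported block size: %d\n", block_size);
  }
}

int main() {
  dispatch_block_size<float>(16, 128);
  return 0;
}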
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& context_lens, // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
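This last hunk is the change the commit title names: the TODO is resolved by adding a Float branch ahead of the existing Half and BFloat16 ones. Note that Half dispatches as uint16_t, which suggests the FP16 kernels handle half values as raw 16-bit storage, while float needs no such reinterpretation. A self-contained sketch of the dispatch (launch_attention stands in for CALL_KERNEL_LAUNCHER_BLOCK_SIZE and is not part of the source):

#include <torch/torch.h>
#include <cuda_bf16.h>
#include <cstdint>

template <typename T>
void launch_attention(const torch::Tensor& query) {
  // Placeholder for CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T).
}

void dispatch_by_dtype(const torch::Tensor& query) {
  if (query.dtype() == at::ScalarType::Float) {
    launch_attention<float>(query);      // the branch this commit adds
  } else if (query.dtype() == at::ScalarType::Half) {
    launch_attention<uint16_t>(query);   // half treated as raw 16-bit storage
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    launch_attention<__nv_bfloat16>(query);
  } else {
    TORCH_CHECK(false, "Unsupported dtype: ", query.dtype());
  }
}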