Support various block sizes & Change default block size to 16 (#38)

Woosuk Kwon
2023-04-15 09:03:24 -07:00
committed by GitHub
parent 84eee24e20
commit 0f4b32199e
7 changed files with 594 additions and 611 deletions
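For context on what the new default controls, here is a hedged, self-contained sketch of the paged KV-cache addressing that a block size parameterizes: a token's position in a sequence splits into a logical block index and an offset within that block, and a per-sequence block table maps the logical block to a physical cache block. All names and values below are illustrative, not taken from this commit.

#include <cstdio>

int main() {
  const int block_size = 16;             // the new default set by this commit
  const int block_table[] = {7, 2, 9};   // hypothetical physical block ids for one sequence
  const int token_pos = 37;              // 0-based token position in the sequence

  const int logical_block  = token_pos / block_size;      // 37 / 16 = 2
  const int block_offset   = token_pos % block_size;      // 37 % 16 = 5
  const int physical_block = block_table[logical_block];  // -> block 9

  std::printf("token %d -> physical block %d, slot %d\n",
              token_pos, physical_block, block_offset);
  return 0;
}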


@@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
   int block_size,
   int max_context_len);
-void multi_query_cached_kv_attention(
-  torch::Tensor& cu_query_lens,
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len);
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "single_query_cached_kv_attention",
     &single_query_cached_kv_attention,
     "Compute the attention between an input query and the cached key/value tensors");
-  m.def(
-    "multi_query_cached_kv_attention",
-    &multi_query_cached_kv_attention,
-    "Compute the attention between multiple input queries and the cached key/value tensors");
 }
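The hunk above drops the multi-query binding, leaving single_query_cached_kv_attention as the only exported kernel. As for the variable block-size support named in the title, a common way to let one binding accept several block sizes is to template the launcher on the block size and dispatch on the runtime value. The sketch below illustrates that pattern with hypothetical names; it is not the commit's actual dispatch code.

#include <stdexcept>
#include <string>

// Each instantiation can specialize its loops and shared-memory layout
// for a compile-time block size.
template <int BLOCK_SIZE>
void launch_attention_kernel(/* kernel arguments elided for brevity */) {}

void dispatch_attention(int block_size) {
  switch (block_size) {
    case 8:  launch_attention_kernel<8>();  break;
    case 16: launch_attention_kernel<16>(); break;  // new default
    case 32: launch_attention_kernel<32>(); break;
    default:
      throw std::runtime_error("Unsupported block size: " +
                               std::to_string(block_size));
  }
}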

File diff suppressed because it is too large.


@@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+inline __device__ float dot(float a, float b)
+{
+  return a * b;
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+inline __device__ float dot(float2 a, float2 b)
+{
+  float2 c = mul<float2, float2, float2>(a, b);
+  return c.x + c.y;
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
 inline __device__ float dot(Float4_ a, Float4_ b)
 {
   float2 acc = mul<float2, float2, float2>(a.x, b.x);
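The new scalar dot overloads added above complete the set, so templated code can accumulate a partial query-key dot product over whatever vector type a thread holds. A minimal sketch of that composition, assuming the dot overloads and vector types from this file (vec_dot itself is a hypothetical helper, not part of the commit):

template <typename Vec, int N>
inline __device__ float vec_dot(const Vec (&q)[N], const Vec (&k)[N]) {
  float acc = 0.0f;
#pragma unroll
  for (int i = 0; i < N; ++i) {
    acc += dot(q[i], k[i]);  // resolves to the float/float2/Float4_ overload
  }
  return acc;
}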
@@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float cast_to_float(float u)
-{
-  return u;
-}
+// inline __device__ float cast_to_float(float u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float2 cast_to_float(float2 u)
-{
-  return u;
-}
+// inline __device__ float2 cast_to_float(float2 u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float4 cast_to_float(float4 u)
-{
-  return u;
-}
+// inline __device__ float4 cast_to_float(float4 u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ Float4_ cast_to_float(Float4_ u)
-{
-  return u;
-}
+// inline __device__ Float4_ cast_to_float(Float4_ u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ Float8_ cast_to_float(Float8_ u)
-{
-  return u;
-}
+// inline __device__ Float8_ cast_to_float(Float8_ u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 inline __device__ float cast_to_float(uint16_t u)
 {
-  return u;
+  return half_to_float(u);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
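The uint16_t overload is the one real bug fix in this hunk: the argument carries the raw bit pattern of an fp16 value, so the old `return u;` performed an integer-to-float conversion of those bits (for example, 0x3C00, which encodes fp16 1.0, would become 15360.0f) instead of widening the half-precision value. A hedged sketch of what a half_to_float helper does, written against CUDA's cuda_fp16.h types (the commit's actual helper may differ):

#include <cstdint>
#include <cuda_fp16.h>

inline __device__ float half_to_float_sketch(uint16_t h) {
  __half_raw raw;
  raw.x = h;                         // reinterpret the 16 bits as an fp16 value
  return __half2float(__half(raw));  // widen fp16 -> fp32 (0x3C00 -> 1.0f)
}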