Support various block sizes & Change default block size to 16 (#38)
This commit is contained in:
@@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
|
||||
int block_size,
|
||||
int max_context_len);
|
||||
|
||||
void multi_query_cached_kv_attention(
|
||||
torch::Tensor& cu_query_lens,
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"single_query_cached_kv_attention",
|
||||
&single_query_cached_kv_attention,
|
||||
"Compute the attention between an input query and the cached key/value tensors");
|
||||
m.def(
|
||||
"multi_query_cached_kv_attention",
|
||||
&multi_query_cached_kv_attention,
|
||||
"Compute the attention between multiple input queries and the cached key/value tensors");
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float dot(float a, float b)
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float dot(float2 a, float2 b)
|
||||
{
|
||||
float2 c = mul<float2, float2, float2>(a, b);
|
||||
return c.x + c.y;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float dot(Float4_ a, Float4_ b)
|
||||
{
|
||||
float2 acc = mul<float2, float2, float2>(a.x, b.x);
|
||||
@@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float cast_to_float(float u)
|
||||
{
|
||||
return u;
|
||||
}
|
||||
// inline __device__ float cast_to_float(float u)
|
||||
// {
|
||||
// return u;
|
||||
// }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float2 cast_to_float(float2 u)
|
||||
{
|
||||
return u;
|
||||
}
|
||||
// inline __device__ float2 cast_to_float(float2 u)
|
||||
// {
|
||||
// return u;
|
||||
// }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float4 cast_to_float(float4 u)
|
||||
{
|
||||
return u;
|
||||
}
|
||||
// inline __device__ float4 cast_to_float(float4 u)
|
||||
// {
|
||||
// return u;
|
||||
// }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ Float4_ cast_to_float(Float4_ u)
|
||||
{
|
||||
return u;
|
||||
}
|
||||
// inline __device__ Float4_ cast_to_float(Float4_ u)
|
||||
// {
|
||||
// return u;
|
||||
// }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ Float8_ cast_to_float(Float8_ u)
|
||||
// inline __device__ Float8_ cast_to_float(Float8_ u)
|
||||
// {
|
||||
// return u;
|
||||
// }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float cast_to_float(uint16_t u)
|
||||
{
|
||||
return u;
|
||||
return half_to_float(u);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user