[Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799)
Co-authored-by: beagleski <yunanzhang@microsoft.com> Co-authored-by: bapatra <bapatra@microsoft.com> Co-authored-by: Barun Patra <codedecde@users.noreply.github.com> Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher(
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||
int block_size, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale) {
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
||||
@@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher(
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits, torch::Tensor& tmp_out,
|
||||
torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads,
|
||||
float scale, torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale) {
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
|
||||
|
||||
Reference in New Issue
Block a user