[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
@@ -175,6 +175,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"float epsilon) -> ()");
|
||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Function for fused QK Norm and RoPE
|
||||
ops.def(
|
||||
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
|
||||
"int num_heads_k, int num_heads_v, int head_dim, float eps, "
|
||||
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
|
||||
"bool is_neox, Tensor position_ids) -> ()");
|
||||
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
||||
#endif
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
|
||||
Reference in New Issue
Block a user