[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
-#ifndef USE_ROCM
   // Function for fused QK Norm and RoPE
   ops.def(
       "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
@@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
       "bool is_neox, Tensor position_ids) -> ()");
   ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
-#endif
 
   // Apply repetition penalties to logits in-place
   ops.def(
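The change works because ROCm builds of PyTorch expose HIP devices under the CUDA dispatch key, so the existing torch::kCUDA registration already covers ROCm; the #ifndef USE_ROCM guard was the only thing excluding the op there. Below is a minimal, self-contained sketch of the same def/impl registration pattern, using a hypothetical demo_ext namespace and a toy CPU op; nothing in it is vLLM's real code.

// Sketch of the TORCH_LIBRARY def/impl pattern shown in the diff above.
// `demo_ext` and `scale_inplace` are invented for illustration only.
#include <torch/library.h>
#include <torch/torch.h>

// Toy in-place kernel; `Tensor!` in the schema below marks `t` as mutated
// in place, analogous to the `Tensor! qkv` argument of fused_qk_norm_rope.
void scale_inplace(torch::Tensor& t, double factor) {
  t.mul_(factor);
}

TORCH_LIBRARY(demo_ext, ops) {
  // Declare the op's schema, then bind a kernel for a dispatch key.
  // In the vLLM diff the key is torch::kCUDA, which ROCm builds also
  // dispatch through, so no backend guard is needed around it.
  ops.def("scale_inplace(Tensor! t, float factor) -> ()");
  ops.impl("scale_inplace", torch::kCPU, &scale_inplace);
}

Once such a registration is compiled into a loaded extension, the op becomes callable from Python as torch.ops.demo_ext.scale_inplace(t, 2.0).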