[Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
This commit is contained in:
Carl Y
2026-04-02 18:47:04 -07:00
committed by GitHub
parent 1f5ec2889c
commit 3bc2734dd0
8 changed files with 516 additions and 70 deletions

View File

@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context) -> ()");
" int!? prefill_tokens_with_context,"
" Tensor? output_scale=None) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(