[ROCm] Remove unnecessary fp8 roundtrip in gather cache NHD dequant (#39122)

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
This commit is contained in:
Andrew Barnes
2026-04-09 03:12:22 -04:00
committed by GitHub
parent ed2f282bc8
commit 8a34c5087a

View File

@@ -112,10 +112,12 @@ if current_platform.is_rocm():
if DEQUANT:
k_scale = tl.load(k_scale_ptr)
v_scale = tl.load(v_scale_ptr)
-k_dtype = k_reg.dtype
-v_dtype = v_reg.dtype
-k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
-v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+k_reg = (k_reg.to(tl.float32) * k_scale).to(
+key_ptr_offset.dtype.element_ty
+)
+v_reg = (v_reg.to(tl.float32) * v_scale).to(
+value_ptr_offset.dtype.element_ty
+)
tl.store(key_ptr_offset + col_offsets, k_reg)
tl.store(value_ptr_offset + col_offsets, v_reg)