[ROCm] Remove unnecessary fp8 roundtrip in gather cache NHD dequant (#39122)
Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
This commit is contained in:
@@ -112,10 +112,12 @@ if current_platform.is_rocm():
|
||||
if DEQUANT:
|
||||
k_scale = tl.load(k_scale_ptr)
|
||||
v_scale = tl.load(v_scale_ptr)
|
||||
k_dtype = k_reg.dtype
|
||||
v_dtype = v_reg.dtype
|
||||
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
|
||||
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
|
||||
k_reg = (k_reg.to(tl.float32) * k_scale).to(
|
||||
key_ptr_offset.dtype.element_ty
|
||||
)
|
||||
v_reg = (v_reg.to(tl.float32) * v_scale).to(
|
||||
value_ptr_offset.dtype.element_ty
|
||||
)
|
||||
tl.store(key_ptr_offset + col_offsets, k_reg)
|
||||
tl.store(value_ptr_offset + col_offsets, v_reg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user