output type conversion fix (#27159)
This commit is contained in:
@@ -134,10 +134,7 @@ def matmul_kernel_persistent(
 bias_ptrs = bias_ptr + offs_cn
 bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
 accumulator += bias
-if c_ptr.dtype.element_ty == tl.float8e4nv:
-    c = accumulator.to(tl.float8e4nv)
-else:
-    c = accumulator.to(tl.float16)
+c = accumulator.to(c_ptr.dtype.element_ty)
 tl.store(c_ptrs, c, mask=c_mask)
 
 
Reference in New Issue
Block a user