[Bugfix] Fix wrong compute type for float32 inputs in fused_experts_impl (#11921)
Signed-off-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
Co-authored-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
This commit is contained in:
@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
|||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype)
|
dtype=hidden_states.dtype)
|
||||||
|
|
||||||
compute_type = (tl.bfloat16
|
if hidden_states.dtype == torch.bfloat16:
|
||||||
if hidden_states.dtype == torch.bfloat16 else tl.float16)
|
compute_type = tl.bfloat16
|
||||||
|
elif hidden_states.dtype == torch.float16:
|
||||||
|
compute_type = tl.float16
|
||||||
|
elif hidden_states.dtype == torch.float32:
|
||||||
|
compute_type = tl.float32
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||||
|
|
||||||
if inplace:
|
if inplace:
|
||||||
out_hidden_states = hidden_states
|
out_hidden_states = hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user