[Bugfix] fused_experts_impl wrong compute type for float32 (#11921)
Signed-off-by: shaochangxu.scx <shaochangxu.scx@antgroup.com> Co-authored-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
This commit is contained in:
@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                     device=hidden_states.device,
                                     dtype=hidden_states.dtype)
 
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+    if hidden_states.dtype == torch.bfloat16:
+        compute_type = tl.bfloat16
+    elif hidden_states.dtype == torch.float16:
+        compute_type = tl.float16
+    elif hidden_states.dtype == torch.float32:
+        compute_type = tl.float32
+    else:
+        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
 
     if inplace:
         out_hidden_states = hidden_states
Reference in New Issue
Block a user