[Kernel] Use fused rmsnorm for some models like qwen3 series (#17735)
Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
This commit is contained in:
@@ -139,8 +139,8 @@ class Olmo2Attention(nn.Module):
|
||||
if self.tp_size > 1:
|
||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||
-q = self.q_norm.forward_native(q)
|
||||
-k = self.k_norm.forward_native(k)
|
||||
+q = self.q_norm(q)
|
||||
+k = self.k_norm(k)
|
||||
if self.tp_size > 1:
|
||||
splitter = partial(split_tensor_along_last_dim,
|
||||
num_partitions=self.tp_size)
|
||||
|
||||
Reference in New Issue
Block a user