[Kernel] Use fused RMSNorm for some models, such as the Qwen3 series (#17735)

Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
This commit is contained in:
Wanrui Dai
2025-05-07 14:10:02 +08:00
committed by GitHub
parent 1a45a61387
commit f80ae5bdcf
7 changed files with 19 additions and 15 deletions

View File

@@ -139,8 +139,8 @@ class Olmo2Attention(nn.Module):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm.forward_native(q)
k = self.k_norm.forward_native(k)
q = self.q_norm(q)
k = self.k_norm(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)