[Kernel] Use fused rmsnorm for some models like qwen3 series (#17735)
Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
@@ -133,11 +133,11 @@ class Qwen3Attention(nn.Module):
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
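Context for the change, as a hedged sketch rather than vLLM's actual code: vLLM's RMSNorm layer exposes forward_native as an unfused, pure-PyTorch reference, while calling the module directly goes through forward, which can dispatch to a fused CUDA rms_norm kernel. The old qk-norm call sites pinned themselves to the unfused path; calling self.q_norm(...) / self.k_norm(...) lets the fused kernel run where one is available. A minimal, self-contained sketch of that dispatch pattern follows; the class name RMSNormSketch and the use of torch.nn.functional.rms_norm (PyTorch >= 2.4) as a stand-in for vLLM's fused kernel are illustrative assumptions, not vLLM's API.

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNormSketch(nn.Module):
    """Illustrative RMSNorm with an unfused reference path and a
    dispatched path, mirroring the forward_native/forward split."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Unfused reference: square, mean, rsqrt, and two multiplies
        # each launch their own elementwise kernel.
        orig_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return (x * self.weight.float()).to(orig_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch point: norm(x) lands here, where a fused implementation
        # can be substituted. F.rms_norm is only a stand-in for vLLM's
        # fused CUDA kernel.
        return F.rms_norm(x, (x.shape[-1],), self.weight,
                          self.variance_epsilon)


# Usage mirroring the qk-norm call sites in the diff above; the two
# paths should agree numerically (head_dim = 128 is an assumed value).
head_dim = 128
norm = RMSNormSketch(head_dim)
q = torch.randn(4, 16 * head_dim)  # (num_tokens, num_heads * head_dim)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
fused = norm(q_by_head)                   # new call convention
native = norm.forward_native(q_by_head)   # old call convention
q = fused.view(q.shape)
assert torch.allclose(fused, native, atol=1e-5)

The point of the patch is the call convention, not the math: both paths compute the same normalization, but dispatching through forward lets a backend collapse the whole normalize-and-scale chain into a single fused kernel launch.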