Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -187,7 +187,7 @@ class MLPBlock(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
g = self.router(x)
|
||||
- x = self.experts(hidden_states=x, router_logits=g)
|
||||
+ x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||
|
||||
Reference in New Issue
Block a user