[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91
2025-11-19 02:49:36 +02:00
committed by GitHub
parent 9912b8ccb8
commit 1395461f5f
7 changed files with 92 additions and 90 deletions

View File

@@ -87,8 +87,7 @@ class Mamba2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output)
output = self.mixer(hidden_states)
return output, residual