[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91
2025-11-19 02:49:36 +02:00
committed by GitHub
parent 9912b8ccb8
commit 1395461f5f
7 changed files with 92 additions and 90 deletions


@@ -138,8 +138,7 @@ class BambaMixerDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        output = torch.empty_like(hidden_states)
-        self.mamba(hidden_states, output)
+        output = self.mamba(hidden_states)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
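The hunk above switches the mixer call from an out-parameter style (caller preallocates a buffer, the op writes into it in place) to a plain return value. A minimal sketch of the two calling conventions, using NumPy and hypothetical function names (the real code runs the mamba mixer under PyTorch; the toy body here is just an elementwise op standing in for it):

```python
import numpy as np

# Old style (hypothetical stand-in): the caller preallocates `output` and the
# op mutates it in place. The in-place write obscures the dataflow from a
# graph tracer such as torch.compile.
def mamba_out_param(hidden_states: np.ndarray, output: np.ndarray) -> None:
    np.copyto(output, hidden_states * 2.0)

# New style: the op allocates and returns its result, so the ops that produce
# it (e.g. the linear projections mentioned in the commit title) stay visible
# to the compiler instead of being hidden inside a custom op.
def mamba_return(hidden_states: np.ndarray) -> np.ndarray:
    return hidden_states * 2.0

x = np.ones((2, 4))

buf = np.empty_like(x)
mamba_out_param(x, buf)   # old: two statements and a mutated buffer
y = mamba_return(x)       # new: a single expression

assert np.array_equal(buf, y)
```

Both styles compute the same values; the refactor only changes how the result is handed back, which is what lets the projections be traced outside the custom op.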