[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91
2025-11-19 02:49:36 +02:00
committed by GitHub
parent 9912b8ccb8
commit 1395461f5f
7 changed files with 92 additions and 90 deletions

View File

@@ -567,11 +567,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Process through Mamba mixer
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
)
output = self.mamba(hidden_states)
# residual connection after mamba
hidden_states = residual + output