Mamba V2 Test not Asserting Failures. (#21379)

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
Yu Chin Fabian Lim
2025-07-23 04:40:27 -04:00
committed by GitHub
parent accac82928
commit 32ec9e2f2a
2 changed files with 25 additions and 10 deletions

View File

@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
gate_states[..., local_rank * N:(local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.allclose(output,
ref_output[..., local_rank * N:(local_rank + 1) * N],
atol=1e-3,
rtol=1e-3)
torch.testing.assert_close(output,
ref_output[...,
local_rank * N:(local_rank + 1) * N],
atol=5e-3,
rtol=1e-3)