Mamba V2 Test not Asserting Failures. (#21379)
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
committed by
GitHub
parent
accac82928
commit
32ec9e2f2a
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
|
||||
gate_states[..., local_rank * N:(local_rank + 1) * N],
|
||||
)
|
||||
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
||||
torch.allclose(output,
|
||||
ref_output[..., local_rank * N:(local_rank + 1) * N],
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
torch.testing.assert_close(output,
|
||||
ref_output[...,
|
||||
local_rank * N:(local_rank + 1) * N],
|
||||
atol=5e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user