diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 354e83cd8..0744db0b0 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -504,6 +504,14 @@ def test_backend_correctness( W_UV = torch.randn( kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device ) + + # Scale weights to produce realistic magnitude outputs. + # Without scaling, projection output has std ~sqrt(kv_lora_rank) ≈ 22.6, + # causing extreme attention scores and numerical instability in LSE merging. + weight_scale = 1.0 / (kv_lora_rank**0.5) + W_UK = W_UK * weight_scale + W_UV = W_UV * weight_scale + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) for i, backend in enumerate(BACKENDS_TO_TEST):