[BUGFIX] Fix test_mla_backends.py. Scale MLA projection weights to prevent numerical instability (#32529)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-01-19 22:49:29 +04:00
committed by GitHub
parent a0490be8f1
commit 0727cc9ecf

View File

@@ -504,6 +504,14 @@ def test_backend_correctness(
W_UV = torch.randn(
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
)
# Scale weights to produce realistic magnitude outputs.
# Without scaling, projection output has std ~sqrt(kv_lora_rank) ≈ 22.6,
# causing extreme attention scores and numerical instability in LSE merging.
weight_scale = 1.0 / (kv_lora_rank**0.5)
W_UK = W_UK * weight_scale
W_UV = W_UV * weight_scale
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):