[BUGFIX] Fix test_mla_backends.py. Scale MLA projection weights to prevent numerical instability (#32529)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -504,6 +504,14 @@ def test_backend_correctness(
|
||||
W_UV = torch.randn(
|
||||
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Scale weights to produce realistic magnitude outputs.
|
||||
# Without scaling, projection output has std ~sqrt(kv_lora_rank) ≈ 22.6,
|
||||
# causing extreme attention scores and numerical instability in LSE merging.
|
||||
weight_scale = 1.0 / (kv_lora_rank**0.5)
|
||||
W_UK = W_UK * weight_scale
|
||||
W_UV = W_UV * weight_scale
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
|
||||
Reference in New Issue
Block a user