[BUGFIX] Fix test_mla_backends.py. Scale MLA projection weights to prevent numerical instability (#32529)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -504,6 +504,14 @@ def test_backend_correctness(
|
|||||||
W_UV = torch.randn(
|
W_UV = torch.randn(
|
||||||
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
|
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Scale weights to produce realistic magnitude outputs.
|
||||||
|
# Without scaling, projection output has std ~sqrt(kv_lora_rank) ≈ 22.6,
|
||||||
|
# causing extreme attention scores and numerical instability in LSE merging.
|
||||||
|
weight_scale = 1.0 / (kv_lora_rank**0.5)
|
||||||
|
W_UK = W_UK * weight_scale
|
||||||
|
W_UV = W_UV * weight_scale
|
||||||
|
|
||||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||||
|
|
||||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||||
|
|||||||
Reference in New Issue
Block a user