[fix] Remove trtllm ragged mla prefills (#36540)

Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
This commit is contained in:
Olya Kozlova
2026-03-31 21:30:27 +02:00
committed by GitHub
parent b779eb3363
commit 598190aac3
8 changed files with 185 additions and 35 deletions

View File

@@ -20,7 +20,11 @@ def merge_attn_states_torch(
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
prefill_tokens_with_context: int | None = None,
):
# Apply prefill_tokens_with_context mask if needed
if prefill_tokens_with_context is None:
prefill_tokens_with_context = output.shape[0]
p_lse = prefix_lse
s_lse = suffix_lse
# inf -> -inf
@@ -28,6 +32,9 @@ def merge_attn_states_torch(
s_lse[s_lse == torch.inf] = -torch.inf
# max_lse [NUM_HEADS, NUM_TOKENS]
max_lse = torch.maximum(p_lse, s_lse)
mask = torch.ones((prefix_lse.shape[1], 1, 1), device=p_lse.device)
mask[prefill_tokens_with_context:].fill_(0)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
p_lse_exp = torch.exp(p_lse)
@@ -35,11 +42,16 @@ def merge_attn_states_torch(
out_se = p_lse_exp + s_lse_exp
if output_lse is not None:
output_lse = torch.log(out_se) + max_lse
output_lse[prefill_tokens_with_context:] = suffix_lse[
prefill_tokens_with_context:
]
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output = prefix_output * p_scale + suffix_output * s_scale
output.copy_(
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
)
return output, output_lse
@@ -90,13 +102,18 @@ def generate_markdown_table():
)
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
prefill_tokens_with_context: int | None,
num_tokens: int,
num_query_heads: int,
head_size: int,
output_dtype: torch.dtype,
):
if not current_platform.is_cuda():
pytest.skip(
@@ -111,6 +128,7 @@ def test_merge_attn_states(
print(
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
f"Device: {current_platform.get_device_name()}"
)
@@ -164,6 +182,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
)
torch.accelerator.synchronize()
@@ -176,6 +195,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
)
end.record()
torch.accelerator.synchronize()
@@ -199,6 +219,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
)
torch.accelerator.synchronize()
@@ -211,6 +232,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
)
end.record()
torch.accelerator.synchronize()
@@ -231,6 +253,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
)
torch.accelerator.synchronize()
@@ -243,6 +266,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
)
end.record()
torch.accelerator.synchronize()