[fix] Remove trtllm ragged mla prefills (#36540)
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
This commit is contained in:
@@ -20,7 +20,11 @@ def merge_attn_states_torch(
|
||||
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
):
|
||||
# Apply prefill_tokens_with_context mask if needed
|
||||
if prefill_tokens_with_context is None:
|
||||
prefill_tokens_with_context = output.shape[0]
|
||||
p_lse = prefix_lse
|
||||
s_lse = suffix_lse
|
||||
# inf -> -inf
|
||||
@@ -28,6 +32,9 @@ def merge_attn_states_torch(
|
||||
s_lse[s_lse == torch.inf] = -torch.inf
|
||||
# max_lse [NUM_HEADS, NUM_TOKENS]
|
||||
max_lse = torch.maximum(p_lse, s_lse)
|
||||
|
||||
mask = torch.ones((prefix_lse.shape[1], 1, 1), device=p_lse.device)
|
||||
mask[prefill_tokens_with_context:].fill_(0)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
p_lse_exp = torch.exp(p_lse)
|
||||
@@ -35,11 +42,16 @@ def merge_attn_states_torch(
|
||||
out_se = p_lse_exp + s_lse_exp
|
||||
if output_lse is not None:
|
||||
output_lse = torch.log(out_se) + max_lse
|
||||
output_lse[prefill_tokens_with_context:] = suffix_lse[
|
||||
prefill_tokens_with_context:
|
||||
]
|
||||
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output = prefix_output * p_scale + suffix_output * s_scale
|
||||
output.copy_(
|
||||
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
|
||||
)
|
||||
return output, output_lse
|
||||
|
||||
|
||||
@@ -90,13 +102,18 @@ def generate_markdown_table():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(
|
||||
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
|
||||
prefill_tokens_with_context: int | None,
|
||||
num_tokens: int,
|
||||
num_query_heads: int,
|
||||
head_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
):
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(
|
||||
@@ -111,6 +128,7 @@ def test_merge_attn_states(
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
|
||||
f"Device: {current_platform.get_device_name()}"
|
||||
)
|
||||
|
||||
@@ -164,6 +182,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -176,6 +195,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -199,6 +219,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -211,6 +232,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -231,6 +253,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -243,6 +266,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
Reference in New Issue
Block a user