[Bug] Fix DeepGEMM Attention Test (#26423)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
||||
)
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,and->hmn", q, k)
|
||||
score = torch.einsum("mhd,nd->hmn", q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user