[Attention] Clarify comment explaining attn_logits +1 dimension (#33427)
Signed-off-by: Francesco Fusco <ffu@zurich.ibm.com>
@@ -143,8 +143,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 B,
                 q_num_heads,
                 num_kv_splits,
-                # NOTE(lucas) idk why the +1 is here but sglang has it so we
-                # just mirror that
+                # NOTE: the +1 stores the LogSumExp (LSE) that the stage2
+                # kernel uses to merge partial attention outputs across splits.
                 self.kv_lora_rank + 1,
             ),
             dtype=torch.float32,
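The merge the new comment refers to can be sketched numerically: each KV split produces a partial softmax-attention output plus that split's LogSumExp (LSE), and a second stage rescales each partial output by `exp(lse_i - lse_total)` and sums. A minimal NumPy sketch, with illustrative function names (not vLLM's actual kernels):

```python
import numpy as np

def partial_attention(q, k, v):
    # One split: softmax attention over this split's keys, plus the
    # split's log-sum-exp (LSE) needed to merge with other splits.
    logits = k @ q                      # (n,)
    m = logits.max()
    w = np.exp(logits - m)
    lse = m + np.log(w.sum())           # logsumexp(logits), numerically stable
    out = (w / w.sum()) @ v             # (d,) partial attention output
    return out, lse

def merge_splits(parts):
    # Stage-2-style merge: each split's weight share of the global
    # softmax denominator is exp(lse_i - lse_total).
    outs, lses = zip(*parts)
    lses = np.array(lses)
    m = lses.max()
    lse_total = m + np.log(np.exp(lses - m).sum())
    scales = np.exp(lses - lse_total)
    return sum(s * o for s, o in zip(scales, outs))

rng = np.random.default_rng(0)
q = rng.normal(size=8)
k = rng.normal(size=(16, 8))
v = rng.normal(size=(16, 4))

full, _ = partial_attention(q, k, v)                 # single-split reference
merged = merge_splits([partial_attention(q, k[:9], v[:9]),
                       partial_attention(q, k[9:], v[9:])])
assert np.allclose(full, merged)
```

This is why each split's row in `attn_logits` carries one extra slot beyond `kv_lora_rank`: without the LSE, the partial outputs could not be correctly renormalized against the global softmax denominator.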
||||