Feature/mla tests (#23195)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* For decode (i.e. the memory friendly approach) the attention "simulates" a
|
||||
multi-head attention, while the compute is similar to multi-query attention.
|
||||
|
||||
@@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention(
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
|
||||
v
|
||||
)
|
||||
)
|
||||
return spda_o @ W_O
|
||||
|
||||
NOTE: in the actual code,
|
||||
@@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O
|
||||
|
||||
## Chunked Prefill
|
||||
|
||||
For chunked prefill we want to use the compute friendly algorithm. We are
|
||||
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
|
||||
For chunked prefill we want to use the compute friendly algorithm. We are
|
||||
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
|
||||
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
|
||||
|
||||
However, the compute-friendly approach can potentially run out of memory if Skv
|
||||
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
|
||||
|
||||
To mitigate this, we chunk the computation of attention with respect to the
|
||||
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
|
||||
To mitigate this, we chunk the computation of attention with respect to the
|
||||
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
|
||||
fixed workspace size.
|
||||
|
||||
The chunked prefill approach is as follows:
|
||||
|
||||
MCC Max chunk of context to process per iter, computed dynamically,
|
||||
MCC Max chunk of context to process per iter, computed dynamically,
|
||||
used to bound the memory usage
|
||||
|
||||
q_c = h_t @ W_DQ
|
||||
@@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
|
||||
new_v,
|
||||
casual=True,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
)
|
||||
|
||||
// Compute attention with the already existing context
|
||||
for chunk_idx in range(cdiv(C, MCC)):
|
||||
|
||||
Reference in New Issue
Block a user