[ROCm] AITER fused RoPE+KVCache (#33443)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: charlifu <charlifu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
This commit is contained in:
@@ -828,7 +828,7 @@ def unified_mla_attention(
|
||||
k_pe: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
|
||||
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
|
||||
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
@@ -862,7 +862,7 @@ def unified_mla_attention_with_output(
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
|
||||
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
|
||||
layer.forward_impl(
|
||||
q,
|
||||
kv_c_normed,
|
||||
|
||||
Reference in New Issue
Block a user