[ROCm] AITER fused RoPE+KVCache (#33443)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: charlifu <charlifu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
This commit is contained in:
Rohan Potdar
2026-02-23 21:06:00 -06:00
committed by GitHub
parent 95642441d0
commit 2ff4e51152
19 changed files with 1211 additions and 83 deletions

View File

@@ -570,11 +570,11 @@ direct_register_custom_op(
def get_attention_context(
layer_name: str,
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]:
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
"""Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer
instance, and KV cache tensor for a specific layer.
instance, KV cache tensor, and slot mapping for a specific layer.
Args:
layer_name: The name/identifier of the attention layer.
@@ -585,6 +585,7 @@ def get_attention_context(
no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current virtual engine
- slot_mapping: The slot mapping for this specific layer
Note: attn_metadata may be None, but attn_layer and kv_cache are always
extracted from the forward context.
@@ -593,9 +594,14 @@ def get_attention_context(
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_layer = forward_context.no_compile_layers[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
return attn_metadata, attn_layer, kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
@maybe_transfer_kv_layer
@@ -605,7 +611,7 @@ def unified_attention(
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
return output
@@ -636,15 +642,7 @@ def unified_kv_cache_update(
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
@@ -691,7 +689,7 @@ def unified_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, self, kv_cache = get_attention_context(layer_name)
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
self.impl.forward(
self,

View File

@@ -40,8 +40,8 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
layer_name: str = args[layer_name_index]
# Extract attention context (layer-specific metadata, layer, and kv_cache)
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)

View File

@@ -828,7 +828,7 @@ def unified_mla_attention(
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
@@ -862,7 +862,7 @@ def unified_mla_attention_with_output(
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
layer.forward_impl(
q,
kv_c_normed,