[ROCm] AITER fused RoPE+KVCache (#33443)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: charlifu <charlifu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
@@ -570,11 +570,11 @@ direct_register_custom_op(
 def get_attention_context(
     layer_name: str,
-) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]:
+) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
     """Extract attention context for a given layer.
 
     This helper function extracts the attention metadata, attention layer
-    instance, and KV cache tensor for a specific layer.
+    instance, KV cache tensor, and slot mapping for a specific layer.
 
     Args:
         layer_name: The name/identifier of the attention layer.
@@ -585,6 +585,7 @@ def get_attention_context(
             no metadata available
         - attn_layer: The attention layer instance (Attention or MLAAttention)
         - kv_cache: The KV cache tensor for current virtual engine
+        - slot_mapping: The slot mapping for this specific layer
 
     Note: attn_metadata may be None, but attn_layer and kv_cache are always
     extracted from the forward context.
@@ -593,9 +594,14 @@ def get_attention_context(
     attn_metadata = forward_context.attn_metadata
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
-    attn_layer = forward_context.no_compile_layers[layer_name]
+    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
     kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
-    return attn_metadata, attn_layer, kv_cache
+    slot_mapping = forward_context.slot_mapping
+    assert isinstance(slot_mapping, dict), (
+        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+    )
+    layer_slot_mapping = slot_mapping.get(layer_name)
+    return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
 
 
 @maybe_transfer_kv_layer
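For reference, the helper now returns a four-tuple, with the slot mapping looked up in a per-layer dict. A minimal sketch of that shape (the layer names and sizes below are made up for illustration, not taken from the diff):

import torch

# Illustrative shape of forward_context.slot_mapping after this change:
# a dict keyed by layer name, holding that layer's slot-mapping tensor.
slot_mapping: dict[str, torch.Tensor] = {
    "model.layers.0.attn": torch.arange(16, dtype=torch.int64),
    "model.layers.1.attn": torch.arange(16, dtype=torch.int64),
}

# get_attention_context uses dict.get(), so a layer without an entry
# yields None; unified_kv_cache_update below checks for exactly that.
layer_slot_mapping = slot_mapping.get("model.layers.2.attn")
assert layer_slot_mapping is None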
@@ -605,7 +611,7 @@ def unified_attention(
     value: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    attn_metadata, self, kv_cache = get_attention_context(layer_name)
+    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
     output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
 
     return output
@@ -636,15 +642,7 @@ def unified_kv_cache_update(
     Returns a dummy that is passed to unified_attention to signal a side effect and
     the data dependency between them to ensure torch.compile preserves ordering.
     """
-    forward_context = get_forward_context()
-    attn_layer = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
-
-    slot_mapping = forward_context.slot_mapping
-    assert isinstance(slot_mapping, dict), (
-        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
-    )
-    layer_slot_mapping = slot_mapping.get(layer_name)
+    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
     if layer_slot_mapping is not None:
         assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
             f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
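The docstring's "dummy" is a data-dependency trick: a side-effecting cache update returns a small tensor that the attention op then consumes, so torch.compile cannot reorder or eliminate the update. A minimal sketch of the pattern, with illustrative names and placeholder bodies rather than vLLM's actual ops:

import torch

def kv_cache_update(key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Imagine an in-place write of key/value into the cache here
    # (the side effect). Returning a tiny dummy tensor turns that
    # side effect into a visible output.
    return torch.empty(0, device=key.device)

def attention(query: torch.Tensor, dummy_dep: torch.Tensor) -> torch.Tensor:
    # Taking (and discarding) the dummy makes the cache update an input
    # of this op, so the compiler must schedule the update first.
    del dummy_dep
    return query  # placeholder for the real attention computation

dummy = kv_cache_update(torch.ones(4), torch.ones(4))
out = attention(torch.ones(4), dummy)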
@@ -691,7 +689,7 @@ def unified_attention_with_output(
     # that ensures torch.compile preserves ordering between KV cache update and
     # attention forward.
     del kv_cache_dummy_dep
-    attn_metadata, self, kv_cache = get_attention_context(layer_name)
+    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
 
     self.impl.forward(
         self,
@@ -40,8 +40,8 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
         layer_name: str = args[layer_name_index]
 
-        # Extract attention context (layer-specific metadata, layer, and kv_cache)
-        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
+        # Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
+        attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
         connector = get_kv_transfer_group()
         if attn_metadata is None or not connector.has_connector_metadata():
             return func(*args, **kwargs)
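The decorator hunk above follows a guard pattern: resolve per-layer context first, then fall through to the undecorated op when there is no connector work to do. A self-contained sketch of that shape, with hypothetical names standing in for vLLM's forward context and KV connector:

from collections.abc import Callable
from functools import wraps

# Hypothetical per-layer registry standing in for vLLM's forward context.
_LAYER_CONTEXT: dict[str, dict] = {}

def maybe_transfer(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(*args, layer_name: str, **kwargs):
        ctx = _LAYER_CONTEXT.get(layer_name)
        if ctx is None:
            # No metadata for this layer: run the op unchanged,
            # mirroring the `return func(*args, **kwargs)` early exit.
            return func(*args, layer_name=layer_name, **kwargs)
        ctx["loads"] = ctx.get("loads", 0) + 1  # stand-in for the KV transfer
        return func(*args, layer_name=layer_name, **kwargs)
    return wrapper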
@@ -828,7 +828,7 @@ def unified_mla_attention(
     k_pe: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    attn_metadata, layer, kv_cache = get_attention_context(layer_name)
+    attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
 
     return output
@@ -862,7 +862,7 @@ def unified_mla_attention_with_output(
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
-    attn_metadata, layer, kv_cache = get_attention_context(layer_name)
+    attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     layer.forward_impl(
         q,
         kv_c_normed,