[ROCm][Quantization] add fp8xfp8 attn support for rocm_aiter_unified_attn (#36927)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
Author: Divakar Verma
Authored: 2026-03-17 20:49:32 -04:00
Committed by: GitHub
parent 09e4576f65
commit e6c4797704

View File

@@ -125,6 +125,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
self.supports_quant_query_input = True
def forward(
self,
@@ -190,12 +191,20 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
softmax_scale = self.scale
fp8_post_attn_v_rescale = False
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
# Compensate by absorbing q_scale and k_scale into softmax_scale, and
# v_scale into output_scale (or post-multiplying if no fusion).
if query.dtype == self.fp8_dtype:
softmax_scale = self.scale * layer._q_scale_float * layer._k_scale_float
if output_scale is not None:
output_scale = output_scale / layer._v_scale_float
else:
fp8_post_attn_v_rescale = True
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
@@ -217,19 +226,22 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
q_descale=None, # q_scale absorbed into softmax_scale
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
if fp8_post_attn_v_rescale:
output[:num_actual_tokens].mul_(layer._v_scale_float)
return output
def do_kv_cache_update(