[ROCm][Quantization] add fp8xfp8 attn support for rocm_aiter_unified_attn (#36927)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
@@ -125,6 +125,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
from aiter.ops.triton.unified_attention import unified_attention
|
||||
|
||||
self.unified_attention = unified_attention
|
||||
self.supports_quant_query_input = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -190,12 +191,20 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
softmax_scale = self.scale
|
||||
fp8_post_attn_v_rescale = False
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
# When Q is FP8, the Triton kernel skips K/V dequantization (for the fp8xfp8 matmul).
|
||||
# Compensate by absorbing q_scale and k_scale into softmax_scale, and
|
||||
# v_scale into output_scale (or post-multiplying the output by v_scale when output_scale is None).
|
||||
if query.dtype == self.fp8_dtype:
|
||||
softmax_scale = self.scale * layer._q_scale_float * layer._k_scale_float
|
||||
if output_scale is not None:
|
||||
output_scale = output_scale / layer._v_scale_float
|
||||
else:
|
||||
fp8_post_attn_v_rescale = True
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
@@ -217,19 +226,22 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
q_descale=None, # q_scale absorbed into softmax_scale
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
if fp8_post_attn_v_rescale:
|
||||
output[:num_actual_tokens].mul_(layer._v_scale_float)
|
||||
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
|
||||
Reference in New Issue
Block a user