[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel (#23647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
@@ -194,19 +194,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        FlashInferBackend.validate_head_size(self.head_dim)
        self.page_size = self.kv_cache_spec.block_size

        self.enable_fusion = (
            self.compilation_config.pass_config.enable_attn_fusion)
        self.q_data_type = self.model_config.dtype
        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = (
                FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.cache_dtype))
            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
            if self.enable_fusion:
                self.q_data_type = self.kv_cache_dtype
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype
            self.q_data_type = self.kv_cache_dtype

        self._cascade_wrapper = None  # Wrapper for cascade attention
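The hunk above expresses a simple dtype-selection rule: with an FP8 KV cache, the query is switched to the FP8 dtype only when the attn+quant fusion pass is enabled; otherwise the query keeps the model dtype, and a non-FP8 KV cache always follows the model dtype. Below is a minimal, self-contained sketch of that rule; the helper name, the string cache-dtype argument, and the FP8 dtype mapping are simplifications for illustration, not the real vLLM config plumbing.

import torch

def select_attention_dtypes(model_dtype: torch.dtype,
                            cache_dtype: str,
                            enable_attn_fusion: bool):
    """Sketch of the q/kv dtype selection shown in the hunk above."""
    q_data_type = model_dtype
    if cache_dtype.startswith("fp8"):
        # Simplified stand-in for FlashInferBackend.get_fp8_dtype_for_flashinfer.
        kv_cache_dtype = (torch.float8_e5m2 if cache_dtype == "fp8_e5m2"
                          else torch.float8_e4m3fn)
        # Query is quantized to FP8 only when attn+quant fusion is enabled.
        if enable_attn_fusion:
            q_data_type = kv_cache_dtype
    else:
        kv_cache_dtype = model_dtype
    return q_data_type, kv_cache_dtype

# FP8 KV cache without fusion: queries stay in the model dtype (BF16 here).
print(select_attention_dtypes(torch.bfloat16, "fp8_e4m3", False))
# FP8 KV cache with fusion enabled: queries are quantized to FP8 as well.
print(select_attention_dtypes(torch.bfloat16, "fp8_e4m3", True))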
@@ -668,8 +664,6 @@ class FlashInferImpl(AttentionImpl):

        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
            assert attn_metadata.q_data_type != FP8_DTYPE, \
                "Query can only be FP8 if output fusion happened."
            assert output_block_scale is None, "output_block_scale " \
                "is not supported when fusion has not happened"
        else:
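For context, output_scale being provided is what signals that the attn+quant output fusion took place; when it is absent, the guard above rejects an already-FP8 query and any output_block_scale. The sketch below only illustrates the shape of that guard with hypothetical free-standing arguments instead of the real forward() signature; since the rendered diff does not mark which of these lines the commit keeps or removes, treat it as the check's pattern rather than the final behavior.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # stand-in for the backend's FP8_DTYPE constant

def check_fusion_inputs(q_dtype, output_scale, output_block_scale):
    """Sketch of the precondition guard at the top of forward()."""
    if output_scale is None:
        # No fused output quantization: the query must not already be FP8 ...
        assert q_dtype != FP8_DTYPE, \
            "Query can only be FP8 if output fusion happened."
        # ... and a block scale for the output is likewise not accepted.
        assert output_block_scale is None, \
            "output_block_scale is not supported when fusion has not happened"

# Fine: BF16 query, no fused output quantization requested.
check_fusion_inputs(torch.bfloat16, None, None)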
@@ -697,7 +691,8 @@ class FlashInferImpl(AttentionImpl):
            elif output.dtype == FP4_DTYPE:
                self.o_sf_scale = layer._o_scale_float

        # Insert FP8 quant for query
        if attn_metadata.q_data_type == FP8_DTYPE:
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
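The hunk is truncated here, but the visible lines show the step this commit relies on: when the metadata marks the query as FP8, the query tensor is flattened and quantized with vLLM's ops.scaled_fp8_quant before being handed to the TRTLLM attention kernel. Below is a plain-PyTorch sketch of an equivalent per-tensor FP8 (e4m3) quantization of the query, shown only to illustrate the reshape-quantize-reshape pattern; it is not vLLM's fused kernel, and the scale value and tensor shapes are made up.

import torch

def quantize_query_fp8(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    """Per-tensor FP8 (e4m3) quantization of a (tokens, heads, head_size) query."""
    num_tokens, num_heads, head_size = query.shape
    flat = query.reshape(num_tokens, num_heads * head_size).contiguous()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Scale, clamp to the representable e4m3 range, then cast to FP8.
    q_fp8 = (flat.float() / q_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q_fp8.reshape(num_tokens, num_heads, head_size)

query = torch.randn(4, 8, 128, dtype=torch.bfloat16)
query_fp8 = quantize_query_fp8(query, torch.tensor(0.05))
print(query_fp8.shape, query_fp8.dtype)  # torch.Size([4, 8, 128]) torch.float8_e4m3fn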