[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (#14390)

Signed-off-by: luka <luka@neuralmagic.com>
Author: Luka Govedič
Date: 2025-03-07 00:20:16 -05:00
Committed by: GitHub
Parent: dae6896977
Commit: e1744502c2

11 changed files with 257 additions and 231 deletions

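The hunks below replace calls to the free function apply_fp8_linear_generic(...) with a method call on a per-layer Fp8LinearGenericOp instance created once in __init__. As a rough illustration only, here is a minimal sketch of the interface those call sites imply; everything beyond the class name and the apply(...) argument order visible in the diff is an assumption, not the actual vLLM implementation:

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max


class Fp8LinearGenericOp:
    """Object-style replacement for the free function
    apply_fp8_linear_generic: per-instance state (e.g. dispatch decisions)
    can be resolved once when the layer is built instead of on every
    forward call."""

    def apply(self,
              input: torch.Tensor,         # [M, K] high-precision activations
              weight: torch.Tensor,        # [N, K] weights, already fp8
              weight_scale: torch.Tensor,  # dequantization scale(s) for weight
              input_group_shape=None,
              weight_group_shape=None) -> torch.Tensor:
        # Per-tensor dynamic quantization of the activation. The group-shape
        # arguments mirror the call sites in the diff but are unused here:
        # this sketch implements only the per-tensor path.
        x_scale = input.abs().amax().clamp(min=1e-12) / FP8_MAX
        x_fp8 = (input / x_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
        # Reference dequantize-then-matmul; the real op would dispatch to a
        # fused scaled-mm kernel rather than round-tripping precision.
        w_dq = weight.to(input.dtype) * weight_scale
        return (x_fp8.to(input.dtype) * x_scale) @ w_dq.t()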

@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
@@ -640,6 +640,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
+        self.fp8_linear_generic = Fp8LinearGenericOp()
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -653,7 +654,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
-                output_parallel = apply_fp8_linear_generic(
+                output_parallel = self.fp8_linear_generic.apply(
                     x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape)
@@ -673,7 +674,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_Q_UK):
-                return apply_fp8_linear_generic(
+                return self.fp8_linear_generic.apply(
                     x, self.W_Q_UK, self.W_Q_UK_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape).view(
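For completeness, a hypothetical end-to-end use of the Fp8LinearGenericOp sketch above, mirroring the pattern in these hunks (construct once in __init__, call apply() per forward); the shapes, dtypes, and per-tensor weight quantization below are illustrative, not taken from the commit:

import torch

# Reuses Fp8LinearGenericOp, FP8_DTYPE, and FP8_MAX from the sketch above.
op = Fp8LinearGenericOp()            # in the diff: created once in __init__
x = torch.randn(4, 128)              # activations [M, K]
w = torch.randn(256, 128)            # weights [N, K]
w_scale = w.abs().amax() / FP8_MAX   # per-tensor weight scale
w_fp8 = (w / w_scale).to(FP8_DTYPE)  # offline-quantized fp8 weight
out = op.apply(x, w_fp8, w_scale)    # group shapes omitted in this sketch
print(out.shape)                     # torch.Size([4, 256])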