[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (#14390)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
- apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
+ Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_quantize)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
@@ -640,6 +640,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
+ self.fp8_linear_generic = Fp8LinearGenericOp()
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on ROCm and the
|
||||
@@ -653,7 +654,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_UV_O):
|
||||
- output_parallel = apply_fp8_linear_generic(
|
||||
+ output_parallel = self.fp8_linear_generic.apply(
|
||||
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape)
|
||||
@@ -673,7 +674,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_Q_UK):
|
||||
- return apply_fp8_linear_generic(
|
||||
+ return self.fp8_linear_generic.apply(
|
||||
x, self.W_Q_UK, self.W_Q_UK_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape).view(
|
||||
|
||||
Reference in New Issue
Block a user