[Attention] MLA get rid of materialization (#14770)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -13,10 +13,9 @@ import triton.language as tl
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
_normalize_quant_group_shape, scaled_dequantize)
|
||||
scaled_dequantize)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported)
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@@ -101,60 +100,6 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
# Unify the interface between `apply_w8a8_block_fp8_linear` and
|
||||
# `apply_fp8_linear`
|
||||
# NOTE(lucas): this is quite messy, we should think through this more formally
|
||||
# TODO(luka): unify this better
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
class Fp8LinearGenericOp:
    """Unified front-end over the two FP8 linear paths.

    Dispatches to either the block-quantized path
    (`apply_w8a8_block_fp8_linear`) or the generic `Fp8LinearOp` path,
    choosing based on the input/weight quantization group shapes.
    (See the file-level note: this unifies the interfaces of
    `apply_w8a8_block_fp8_linear` and `apply_fp8_linear`.)
    """

    def __init__(
        self,
        # NOTE(review): these defaults are evaluated once, at class
        # definition (module import) time — they capture platform
        # support at import, not per-instance. Presumably intentional;
        # confirm if lazy per-call detection is ever needed.
        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
        cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
    ):
        # Whether the CUTLASS block-FP8 kernel may be used on the
        # block-quantized path.
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
        # Generic (non-block) FP8 linear op used as the fallback path.
        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_supported)

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_group_shape: Tuple[int, int],
        weight_group_shape: Tuple[int, int],
        input_scale: Optional[torch.Tensor] = None,  # static scale if one
    ) -> torch.Tensor:
        """Apply an FP8 linear to `input`, selecting the best kernel.

        Args:
            input: activation tensor; flattened to 2D (last dim kept)
                before dispatch.
            weight: weight matrix. NOTE: on the fallback path it is
                explicitly transposed (see comment below), so callers
                pass the non-transposed layout.
            weight_scale: dequantization scale(s) for `weight`;
                transposed alongside `weight` on the fallback path.
            input_group_shape: activation quantization group shape;
                normalized via `_normalize_quant_group_shape`.
            weight_group_shape: weight quantization group shape;
                normalized likewise.
            input_scale: optional static activation scale. Not
                referenced in the visible body — presumably consumed by
                a caller-side path; verify before relying on it.

        Returns:
            The linear-layer output tensor from whichever path was
            selected.
        """
        # View input as 2D matrix for fp8 methods
        input = input.view(-1, input.shape[-1])

        weight_group_shape = _normalize_quant_group_shape( \
            weight, weight_group_shape)
        input_group_shape = _normalize_quant_group_shape(
            input, input_group_shape)

        # A dimension counts as "blocked" when its quant group spans
        # more than one element but less than the full dimension.
        def is_dim_blocked(dim, shape, group_shape):
            return group_shape < shape[dim] and group_shape > 1

        # Block path: weight blocked along both dims, and the
        # activation groups are per-row (1, K-group) matching the
        # weight's second group dimension.
        if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
            and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
            input_group_shape == (1, weight_group_shape[1]):
            return apply_w8a8_block_fp8_linear(
                input,
                weight,
                list(weight_group_shape),
                weight_scale,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
        else:
            # Despite having linear in the name it doesn't conform to
            # `torch.nn.functional.linear` which is defined as
            # `input @ weight.T` so we explicitly transpose the weight matrix
            return self.fp8_linear.apply(input, weight.T, weight_scale.T,
                use_per_token_if_dynamic=\
                (input_group_shape == (1, input.shape[1])))
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
x: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None
|
||||
|
||||
Reference in New Issue
Block a user