Refactor dense FP8 tensor/channel/block utils and add CT FP8 block (#21404)
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ChannelQuantScaleParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
@@ -794,3 +797,220 @@ def requant_weight_ue8m0_inplace(
    # Write back the results in-place.
    w_q.copy_(w_requant)
    s_old.copy_(s_requant)


def check_aiter_fp8_linear_support() -> bool:
    """AITER is only supported on ROCm, and only for FP8_FNUZ, which at the
    moment means the MI300 series."""
    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_LINEAR
            and current_platform.is_fp8_fnuz())


def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
    """Pad the weight tensor. This is an optimization on the ROCm platform,
    which benefits from tensors being located far enough apart in memory."""
    if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
            and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0):
        num_pad = 256 // weight.element_size()
        import torch.nn.functional as F
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
        torch.cuda.empty_cache()
    return weight
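

# Illustrative sketch (not part of this commit): the pad-then-slice trick above
# adds num_pad trailing elements per row and immediately slices them off, so
# the logical shape is unchanged while rows are spaced further apart in memory.
def _example_pad_fp8_weight() -> None:
    # Assumes a ROCm build with VLLM_ROCM_FP8_PADDING=1; on other platforms
    # the weight is returned unchanged.
    w = torch.empty(4096, 4096, dtype=torch.float8_e4m3fn, device="cuda")
    w_padded = _maybe_pad_fp8_weight(w)
    assert w_padded.shape == w.shape  # only the row stride may change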


def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
                             output_size: int, input_size_per_partition: int,
                             output_partition_sizes: list[int],
                             block_size: list[int]) -> None:
    """Validate block quantization shapes for tensor parallelism."""
    from vllm.distributed import get_tensor_model_parallel_world_size

    tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
    block_n, block_k = block_size[0], block_size[1]

    # Required by row parallel
    if (tp_size > 1 and input_size // input_size_per_partition == tp_size
            and input_size_per_partition % block_k != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition} "
            f"is not divisible by weight quantization block_k = {block_k}.")

    # Required by column parallel or enabling merged weights
    is_tp_split = (tp_size > 1
                   and output_size // sum(output_partition_sizes) == tp_size)
    is_merged_gemm = len(output_partition_sizes) > 1
    if is_tp_split or is_merged_gemm:
        sizes_to_check = output_partition_sizes
        if not is_tp_split and is_merged_gemm:
            # In case of merged matrices, we allow the last
            # matrix to not be a multiple of block size
            sizes_to_check = output_partition_sizes[:-1]
        for output_partition_size in sizes_to_check:
            if output_partition_size % block_n != 0:
                raise ValueError(
                    f"Weight output_partition_size = "
                    f"{output_partition_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")


def create_fp8_weight_parameter(
        output_size_per_partition: int, input_size_per_partition: int,
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create FP8 weight parameter."""
    from vllm.model_executor.parameter import ModelWeightParameter

    return ModelWeightParameter(data=torch.empty(output_size_per_partition,
                                                 input_size_per_partition,
                                                 dtype=torch.float8_e4m3fn),
                                input_dim=1,
                                output_dim=0,
                                weight_loader=weight_loader)


def create_fp8_scale_parameter(
        parameter_type: torch.nn.Parameter, output_partition_sizes: list[int],
        input_size_per_partition: int, block_size: Optional[list[int]],
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create scale parameter based on quantization strategy."""
    if parameter_type == ChannelQuantScaleParameter:
        scale = parameter_type(data=torch.empty(
            (sum(output_partition_sizes), 1), dtype=torch.float32),
                               output_dim=0,
                               weight_loader=weight_loader)
    elif parameter_type == BlockQuantScaleParameter:
        assert block_size is not None
        block_n, block_k = block_size[0], block_size[1]
        output_size_per_partition = sum(output_partition_sizes)
        scale = parameter_type(
            data=torch.empty(
                (output_size_per_partition + block_n - 1) // block_n,
                (input_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
    elif parameter_type == PerTensorScaleParameter:
        scale = parameter_type(data=torch.empty(len(output_partition_sizes),
                                                dtype=torch.float32),
                               weight_loader=weight_loader)
    else:
        raise ValueError(f"Unknown parameter type: {parameter_type}")

    scale[:] = torch.finfo(torch.float32).min
    return scale
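

# Illustrative sketch (not part of this commit): for a 4096x4096 partition with
# block_size = [128, 128], the block scale holds one float32 per 128x128 weight
# tile; the ceiling division covers shapes that are not exact multiples.
def _example_block_scale_shape() -> None:
    # weight_loader=None is fine for a standalone illustration.
    scale = create_fp8_scale_parameter(BlockQuantScaleParameter,
                                       output_partition_sizes=[4096],
                                       input_size_per_partition=4096,
                                       block_size=[128, 128],
                                       weight_loader=None)
    assert tuple(scale.shape) == (32, 32)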


def create_fp8_input_scale(
        output_partition_sizes: list[int],
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create input scale parameter for static activation quantization."""
    from vllm.model_executor.parameter import PerTensorScaleParameter

    scale = PerTensorScaleParameter(data=torch.empty(
        len(output_partition_sizes), dtype=torch.float32),
                                    weight_loader=weight_loader)
    scale[:] = torch.finfo(torch.float32).min
    return scale


def process_fp8_weight_tensor_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        logical_widths: list[int],
        input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Process weights for tensor-wise quantization strategy."""
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
            weight=weight, weight_scale=weight_scale, input_scale=input_scale)

    # Requantize with max scale
    weight_scale, weight = requantize_with_max_scale(
        weight=weight,
        weight_scale=weight_scale,
        logical_widths=logical_widths,
    )

    weight = _maybe_pad_fp8_weight(weight)
    return weight, weight_scale, input_scale
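

# Illustrative sketch (not part of this commit): for a merged layer with two
# logical shards, the per-tensor path collapses the per-shard scales into a
# single max scale and requantizes every shard to it.
def _example_tensor_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    w, s, _ = process_fp8_weight_tensor_strategy(
        weight, weight_scale, logical_widths=[2048, 2048])
    return w, s  # s now holds a single fused scale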


def process_fp8_weight_channel_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Process weights for channel-wise quantization strategy."""
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
            weight=weight, weight_scale=weight_scale, input_scale=input_scale)

    return weight, weight_scale, input_scale


def process_fp8_weight_block_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Process weights for block-wise quantization strategy."""
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
            weight=weight, weight_scale=weight_scale)

    weight = _maybe_pad_fp8_weight(weight)
    return weight, weight_scale
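

# Illustrative sketch (not part of this commit) of how a linear method might
# consume the block path when finalizing loaded weights.
def _example_block_strategy(layer: torch.nn.Module) -> None:
    w, s = process_fp8_weight_block_strategy(layer.weight.data,
                                             layer.weight_scale.data)
    layer.weight = torch.nn.Parameter(w, requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(s, requires_grad=False)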


def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
                                        cutlass_block_fp8_supported: bool):
    assert layer.weight_block_size is not None

    from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                      should_use_deepgemm_for_fp8_linear)

    # On Blackwell or Hopper, if DeepGEMM with E8M0 scales is used, the
    # weight and its scales must be requantized to that format in-place.
    if is_deep_gemm_e8m0_used():
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(layer.weight.data,
                                     layer.weight_scale.data, block_sz)
    # SM90 block FP8 CUTLASS requires row-major weight scales
    elif (current_platform.is_device_capability(90)
          and cutlass_block_fp8_supported
          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
                                                     layer.weight)):
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data.T.contiguous(), requires_grad=False)


def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           cutlass_block_fp8_supported: bool,
                           use_aiter_and_is_supported: bool) -> torch.Tensor:
    """Apply block-wise FP8 linear operation."""
    assert layer.weight_block_size is not None

    return torch.ops.vllm.apply_w8a8_block_fp8_linear(
        input=input,
        weight=layer.weight,
        block_size=layer.weight_block_size,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
        cutlass_block_fp8_supported=cutlass_block_fp8_supported,
        use_aiter_and_is_supported=use_aiter_and_is_supported,
    )
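

# Illustrative end-to-end sketch (not part of this commit): a linear method
# post-processes weights once after loading, then calls the block linear op on
# every forward. The capability flags are assumed to be computed elsewhere,
# e.g. from CUTLASS_BLOCK_FP8_SUPPORTED and check_aiter_fp8_linear_support().
def _example_forward(layer: torch.nn.Module, x: torch.Tensor,
                     cutlass_ok: bool, aiter_ok: bool) -> torch.Tensor:
    return apply_fp8_block_linear(layer,
                                  input=x,
                                  bias=None,
                                  cutlass_block_fp8_supported=cutlass_ok,
                                  use_aiter_and_is_supported=aiter_ok)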