[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2026-01-20 14:48:20 +08:00
committed by GitHub
parent e9c83cdc51
commit 148117ea2e
30 changed files with 1467 additions and 1038 deletions

View File

@@ -0,0 +1,5 @@
Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml

View File

@@ -26,15 +26,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed
from ...utils import has_module_attribute, multi_gpu_test
from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from ..backend import TestBackend
@@ -76,49 +75,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)
for i in range(3)
]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
z2 = self.fp8_linear_layers[0](y)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
z3 = self.fp8_linear_layers[1](y2)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
z4 = self.fp8_linear_layers[2](y3)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
@@ -130,7 +120,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
if self.fp8_linear_layers[0].is_quant_fp8_enabled()
else torch.ops.aten.reciprocal.default,
]

View File

@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed
from ...utils import multi_gpu_test
from ...utils import TestFP8Layer, multi_gpu_test
from ..backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
@@ -94,50 +95,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.vllm_config = get_current_vllm_config()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)
for i in range(3)
]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
z2 = self.fp8_linear_layers[0](y)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
z3 = self.fp8_linear_layers[1](y2)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
z4 = self.fp8_linear_layers[2](y3)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return [
torch.ops._C.fused_add_rms_norm.default,
]
elif self.fp8_linear.quant_fp8.enabled():
elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
return [
torch.ops._C.static_scaled_fp8_quant.default,
]

View File

@@ -20,11 +20,13 @@ from vllm.config import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from ..utils import TestFP8Layer
from .backend import TestBackend
TEST_FP8 = current_platform.supports_fp8()
@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)
def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
return self.fp8_linear(y)
else:
return y
@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
torch.nn.init.normal_(self.gate_proj, std=0.02)
if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, intermediate_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)
def forward(self, hidden_states, residual):
# Reshape input
@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
if TEST_FP8:
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
fp8_linear_result = self.fp8_linear(norm_output)
return fp8_linear_result, residual_output

View File

@@ -5,6 +5,7 @@
import pytest
import torch
import vllm.config
import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
@@ -20,8 +21,22 @@ from vllm.config import (
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.deep_gemm import (
is_deep_gemm_supported,
)
from ..utils import override_cutlass_fp8_supported
from ..utils import TestBlockFP8Layer, TestFP8Layer
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
@@ -45,157 +59,195 @@ FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Each entry is (kernel, group_shape); kernel is None for the blockwise
# group shapes, which do not go through the kernel abstraction.
# FlashInferFP8ScaledMMLinearKernel supports per-tensor only
(FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
(CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
(CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
]
# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Each entry is (kernel, group_shape); kernel is None for the blockwise
# group shapes, which do not go through the kernel abstraction.
# ROCmFP8ScaledMMLinearKernel supports per-tensor only
(ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
(RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
]
KERNEL_GROUPSHAPE_COMBINATIONS = (
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
if current_platform.is_cuda()
else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)
# For Aiter tests we toggle use_aiter_quant_op
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Each entry is (kernel, group_shape, use_aiter_quant_op); kernel is None
# for the blockwise case, which does not go through the kernel abstraction.
# Per-tensor with ROCmFP8ScaledMMLinearKernel
(ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
# Per-token with RowWiseTorchFP8ScaledMMLinearKernel
(RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Blockwise (no kernel abstraction)
(None, GroupShape(1, 128), True),
]
class TestModel(torch.nn.Module):
def __init__(
self,
hidden_size: int,
eps: float,
force_kernel: FP8ScaledMMLinearKernel | None,
group_shape: GroupShape,
use_aiter: bool = False,
cuda_force_torch: bool = False,
use_aiter_quant_op: bool = True,
use_aiter_fusion: bool = False,
use_aiter_quant: bool = False,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.use_aiter = use_aiter
self.use_aiter_quant_op = use_aiter_quant_op
self.cuda_force_torch = cuda_force_torch
self.fp8_linear_layers: list[torch.nn.Module]
self.group_shape = group_shape
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
self.use_aiter_quant_op = use_aiter_quant
self.use_aiter_fusion = use_aiter_fusion
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
# Setup quantization scale descriptor
static = group_shape == GroupShape.PER_TENSOR and not use_aiter
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
# Setup scales
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
self.scale = [None for _ in range(3)]
# Setup weights
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
]
if not group_shape.is_per_group() or use_aiter:
self.w = [self.w[0].t() for _ in range(3)]
# Setup weight scales
if group_shape.is_per_group():
scale_size = (
(hidden_size + 128 - 1) // 128
if use_aiter
else hidden_size // group_shape[1]
)
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
else:
wscale_shape = (1,)
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
# Setup FP8 linear operation
is_per_group = group_shape.is_per_group()
if is_per_group and use_aiter:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=group_shape,
use_aiter_and_is_supported=use_aiter_quant_op,
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif is_per_group:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
elif use_aiter:
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=group_shape,
)
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
else:
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=static,
act_quant_group_shape=group_shape,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_rms_norm_custom_op = self.norm[0].enabled()
# Determine if blockwise based on group_shape
is_blockwise = group_shape.is_per_group()
if is_blockwise:
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
)
self.fp8_linear_layers = [
TestBlockFP8Layer(
weight_shape=(hidden_size, hidden_size),
group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=use_aiter_quant,
transpose_weights=use_aiter_fusion,
)
for _ in range(3)
]
self.enable_quant_fp8_custom_op = (
False
if use_aiter_quant
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
)
else:
is_static = group_shape == GroupShape.PER_TENSOR
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
)
for _ in range(3)
]
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = self.fp8_linear_layers[0](y)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = self.fp8_linear_layers[1](y2)
y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = self.fp8_linear_layers[2](y3)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_before(self):
if (
self.use_aiter
and self.group_shape.is_per_group()
and current_platform.is_fp8_fnuz()
):
return [rocm_aiter_ops.get_group_quant_op()]
if self.use_aiter and self.group_shape.is_per_group():
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
if self.use_aiter and self.use_aiter_quant_op:
return [rocm_aiter_ops.get_per_token_quant_op()]
if self.use_aiter:
return [QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op:
return [QUANT_OPS[self.quant_key]]
return [torch.ops.aten.reciprocal]
if self.group_shape.is_per_group():
# Blockwise path
if self.use_aiter_fusion and self.use_aiter_quant_op:
return [rocm_aiter_ops.get_group_quant_op()]
if self.use_aiter_fusion:
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
else:
if self.use_aiter_quant_op:
return [rocm_aiter_ops.get_per_token_quant_op()]
# Common path
return (
[QUANT_OPS[self.activation_quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_after(self):
if self.use_aiter and self.group_shape.is_per_group():
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSFp8GroupQuantPattern,
AiterRMSFp8GroupQuantPattern,
)
if self.use_aiter_fusion:
if self.group_shape.is_per_group():
# Blockwise aiter fusion
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSFp8GroupQuantPattern,
AiterRMSFp8GroupQuantPattern,
)
return [
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
AiterRMSFp8GroupQuantPattern.FUSED_OP,
]
if self.use_aiter:
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSNormDynamicQuantPattern,
)
return [
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
AiterRMSFp8GroupQuantPattern.FUSED_OP,
]
else:
# Per-token aiter fusion
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSNormDynamicQuantPattern,
)
return [
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
AiterRMSNormDynamicQuantPattern.FUSED_OP,
]
return [
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
AiterRMSNormDynamicQuantPattern.FUSED_OP,
]
# Regular fusion
return [
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
]
def ops_in_model_before_partial(self):
@@ -206,14 +258,6 @@ class TestModel(torch.nn.Module):
)
GROUP_SHAPES = [
GroupShape.PER_TOKEN,
GroupShape.PER_TENSOR,
GroupShape(1, 128),
GroupShape(1, 64),
]
def _run_fusion_test(
model,
fusion_pass,
@@ -259,14 +303,9 @@ def _run_fusion_test(
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
@@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant(
hidden_size,
num_tokens,
eps,
group_shape,
kernel_groupshape,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
force_kernel, group_shape = kernel_groupshape
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
@@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant(
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RMSNormQuantFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
use_aiter=False,
cuda_force_torch=cuda_force_torch,
use_aiter_fusion=False,
use_aiter_quant=False,
)
backend, _ = _run_fusion_test(
@@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant(
assert n_add_nodes(backend.graph_post_pass) == 2
GROUP_SHAPE_QUANT_OPS_MATCHS = [
(GroupShape.PER_TOKEN, True),
(GroupShape.PER_TOKEN, False),
(GroupShape(1, 128), True),
]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
"kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
(not current_platform.is_rocm() or not IS_AITER_FOUND),
@@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant(
hidden_size: int,
num_tokens: int,
eps: float,
group_shape: GroupShape,
use_aiter_quant_op: bool,
kernel_groupshape_quant: tuple,
monkeypatch: pytest.MonkeyPatch,
):
force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
@@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant(
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
use_aiter=True,
use_aiter_quant_op=use_aiter_quant_op,
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
)
_run_fusion_test(

View File

@@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -53,6 +52,8 @@ from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec
from ..utils import TestFP8Layer
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape,
hidden_size = self.num_qo_heads * self.head_size
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w",
{
"weight": torch.randn(hidden_size, hidden_size)
.to(dtype=FP8_DTYPE, device=self.device)
.t(),
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
},
)
w = kwargs.get("w")
if w is not None:
self.fp8_linear.weight = w["weight"]
self.fp8_linear.weight_scale = w["wscale"]
self.fp8_linear.input_scale = w["scale"]
self.w = {
"weight": self.fp8_linear.weight,
"wscale": self.fp8_linear.weight_scale,
"scale": self.fp8_linear.input_scale,
}
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(
input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
return self.fp8_linear(attn_output)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):

View File

@@ -25,19 +25,30 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
PerTensorTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported
from ..utils import TestFP8Layer
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
@@ -49,25 +60,27 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
quant_key = kFp8StaticTensorSym
def __init__(
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
force_kernel=force_kernel,
)
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
x2 = self.fp8_linear(y)
return x2
def ops_in_model_before(self):
@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
CUDA_KERNELS = [
FlashInferFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
]
TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
"model_class, enable_quant_fp8_custom_op, force_kernel",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
(TestSiluMulNvfp4QuantModel, False, None),
(TestSiluMulGroupFp8QuantModel, False, None),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant(
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
force_kernel: FP8ScaledMMLinearKernel | None,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant(
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2)
@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant(
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
)
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)

View File

@@ -11,13 +11,13 @@ from abc import ABC
import pytest
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
Int8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
def test_cpu_kernel_implements_is_supported():
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
"CPUScaledMMLinearKernel missing is_supported() method"
"""Test that CPUInt8ScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUInt8ScaledMMLinearKernel, "is_supported"), (
"CPUInt8ScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
CPUScaledMMLinearKernel.is_supported
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
assert inspect.ismethod(
CPUInt8ScaledMMLinearKernel.is_supported
) or inspect.isfunction(CPUInt8ScaledMMLinearKernel.is_supported), (
"CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
result, reason = CPUScaledMMLinearKernel.is_supported()
result, reason = CPUInt8ScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
def test_aiter_kernel_implements_is_supported():
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
"AiterScaledMMLinearKernel missing is_supported() method"
"""Test that AiterInt8ScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterInt8ScaledMMLinearKernel, "is_supported"), (
"AiterInt8ScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(
AiterScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
AiterInt8ScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterInt8ScaledMMLinearKernel.is_supported), (
"AiterInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
# (will return False on CPU, which is expected)
result, reason = AiterScaledMMLinearKernel.is_supported()
result, reason = AiterInt8ScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm
@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
def test_cpu_kernel_accepts_all_configs():
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
"""Test that CPUInt8ScaledMMLinearKernel accepts all config combinations."""
configs = [
ScaledMMLinearLayerConfig(
Int8ScaledMMLinearLayerConfig(
is_channelwise=False,
is_static_input_scheme=True,
input_symmetric=True,
),
ScaledMMLinearLayerConfig(
Int8ScaledMMLinearLayerConfig(
is_channelwise=True,
is_static_input_scheme=False,
input_symmetric=False,
@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
]
for config in configs:
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
can_impl, reason = CPUInt8ScaledMMLinearKernel.can_implement(config)
assert can_impl, (
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
)

View File

@@ -41,7 +41,7 @@ ROCM_AITER_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]
# TritonScaledMMLinearKernel only supports symmetric quantization.
# TritonInt8ScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",

View File

@@ -42,6 +42,17 @@ from vllm.distributed import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
)
from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer
@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import cuda_device_count_stateless
FP8_DTYPE = current_platform.fp8_dtype()
if current_platform.is_rocm():
from amdsmi import (
amdsmi_get_gpu_vram_usage,
@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]):
for element in itertools.product(*iterables):
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
yield tuple(itertools.chain(*normalized))
class TestFP8Layer(torch.nn.Module):
    """
    Test helper for FP8 linear operations. Creates random weights and scales
    based on quantization configuration.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
        activation_quant_key: Activation quantization configuration.
        weight_quant_key: Weight quantization configuration.
        out_dtype: Output dtype. Defaults to current default dtype.
        device: Device passed to the tensor constructors for the random
            weight and scales. Defaults to None.
        force_kernel: Optional kernel class to force use of a specific
            implementation. Forwarded as-is to init_fp8_linear_kernel.
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        activation_quant_key: QuantKey,
        weight_quant_key: QuantKey,
        out_dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        # A kernel *class*, not an instance: init_fp8_linear_kernel selects
        # and instantiates the kernel itself.
        force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
    ):
        super().__init__()
        per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
        is_static_activation_scale = activation_quant_key.scale.static

        # One scale for the whole tensor, or one per output row.
        weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
        self.weight_scale = torch.rand(
            weight_scale_shape, dtype=torch.float32, device=device
        )
        # A static activation scheme carries a fixed input scale; a dynamic
        # one leaves it as None.
        self.input_scale = (
            torch.rand(1, dtype=torch.float32, device=device)
            if is_static_activation_scale
            else None
        )
        # Created as weight_shape, cast to FP8, then stored transposed.
        self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
        self.input_scale_ub = None

        out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
        self.kernel = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=weight_quant_key,
            out_dtype=out_dtype,
            force_kernel=force_kernel,
        )

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the selected kernel's QuantFP8 op reports enabled."""
        return self.kernel.quant_fp8.enabled()

    def forward(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the FP8 linear on `y`, using this module as the weight holder."""
        return self.kernel.apply_weights(self, y, bias)
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class TestBlockFP8Layer:
    """
    Blockwise FP8 linear test helper backed by W8A8BlockFp8LinearOp.

    Generates a random FP8 weight and a random float32 weight scale. The
    scale tensor is square: both of its dimensions are derived from
    weight_shape[0] divided by group_shape[1].

    This is a workaround until W8A8BlockFp8LinearOp implements the kernel
    abstraction (ScaledMMLinearKernel) for blockwise quantization.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
        group_shape: Blockwise quantization group shape.
        cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
        use_aiter_and_is_supported: Whether to use aiter quantization ops.
        transpose_weights: Whether to transpose weights after creation.
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        group_shape: GroupShape,
        cutlass_block_fp8_supported: bool = False,
        use_aiter_and_is_supported: bool = False,
        transpose_weights: bool = False,
    ):
        block_size = group_shape[1]
        blocks_per_dim = weight_shape[0] // block_size

        # Scale is drawn before the weight, so the RNG order is unchanged.
        self.weight_scale = torch.rand(
            (blocks_per_dim, blocks_per_dim), dtype=torch.float32
        )
        fp8_weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
        self.weight = fp8_weight.t() if transpose_weights else fp8_weight
        # No input scale is stored; None is what gets passed to the op.
        self.input_scale = None

        self.linear_op = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_size, block_size),
            act_quant_group_shape=group_shape,
            cutlass_block_fp8_supported=cutlass_block_fp8_supported,
            use_aiter_and_is_supported=use_aiter_and_is_supported,
        )

    def __call__(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the blockwise FP8 linear op to `y` with the stored tensors."""
        op_kwargs = dict(
            input=y,
            weight=self.weight,
            weight_scale=self.weight_scale,
            input_scale=self.input_scale,
            bias=bias,
        )
        return self.linear_op.apply(**op_kwargs)

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the op's input quantization op reports enabled."""
        input_quant = self.linear_op.input_quant_op
        return input_quant.enabled()

View File

@@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl(
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)

View File

@@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
@@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
maybe_create_device_identity,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -42,6 +49,18 @@ strategy_to_parameter_type = {
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
# Boolean aliases used as the keys of activation_quant_key_mapping.
STATIC_QUANT = True
DYNAMIC_QUANT = False
# Activation quant key per input scheme; CompressedTensorsW8A8Fp8 indexes
# this with its is_static_input_scheme flag.
activation_quant_key_mapping = {
STATIC_QUANT: kFp8StaticTensorSym,
DYNAMIC_QUANT: kFp8DynamicTokenSym,
}
# Weight quant key per compressed-tensors weight strategy. Only CHANNEL and
# TENSOR are listed: the block-quantized branch builds W8A8BlockFp8LinearOp
# instead of looking a key up here.
weight_quant_key_mapping = {
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
}
logger = init_logger(__name__)
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
@@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = (
GroupShape.PER_TENSOR
if is_static_input_scheme
else GroupShape.PER_TOKEN
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
if self.weight_block_size is not None:
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
@@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape,
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
weight_quant_key = weight_quant_key_mapping[self.strategy]
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
@classmethod
@@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable,
**kwargs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.weight_block_size = None
@@ -143,7 +155,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
getattr(layer, "input_scale", None),
)
weight = weight.t()
elif self.strategy == QuantizationStrategy.CHANNEL:
weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
@@ -174,7 +185,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else:
layer.input_scale = None
if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer)
@@ -193,11 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias,
)
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel,
init_int8_linear_kernel,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
@@ -25,8 +24,6 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
):
@@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
self.kernel = init_int8_linear_kernel(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
)
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
@@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
input_zero_point = None
input_scale = None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
@@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_scale", input_scale)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.

View File

@@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
@@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig):
class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
)
self.out_dtype = torch.get_default_dtype()
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights(
self,
@@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
@@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias=bias,
)
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@@ -48,6 +48,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
@@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
@@ -328,28 +332,30 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights(
@@ -362,8 +368,6 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
@@ -462,8 +466,6 @@ class Fp8LinearMethod(LinearMethodBase):
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
@@ -602,14 +604,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias,
)
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8MoEMethod(FusedMoEMethodBase):

View File

@@ -2,19 +2,58 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
@dataclass
class ScaledMMLinearLayerConfig:
is_channelwise: bool
pass
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme: bool
is_channelwise: bool
input_symmetric: bool
class ScaledMMLinearKernel(ABC):
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
# Layer description handed to FP8 scaled-mm kernels.
# Quantization scheme of the weights.
weight_quant_key: QuantKey
# Quantization scheme of the activations; its scale descriptor (static flag
# and group shape) configures the kernel's QuantFP8 input quantizer.
activation_quant_key: QuantKey
# Output dtype of the matmul; when None, apply_weights uses the input dtype.
out_dtype: torch.dtype | None
# Tensors an FP8 kernel reads from a layer, in the same order as the
# kernel's layer_param_names.
_FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_scale_ub,
]
# Tensors an INT8 kernel reads from a layer, in the same order as the
# kernel's layer_param_names.
_Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
# Constrained TypeVar: exactly one of the two parameter tuples above.
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
# Bound TypeVar: any ScaledMMLinearLayerConfig subclass.
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
@classmethod
@abstractmethod
def is_supported(
@@ -24,26 +63,14 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(
self,
c: ScaledMMLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
assert self.is_supported()
def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
assert self.can_implement(c)[0]
assert self.is_supported()[0]
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.i_s_name = i_s_param_name
self.i_zp_name = i_zp_param_name
self.azp_adj_name = azp_adj_param_name
self.layer_param_names = layer_param_names
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -58,19 +85,103 @@ class ScaledMMLinearKernel(ABC):
) -> torch.Tensor:
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.i_s_name),
getattr(layer, self.i_zp_name),
getattr(layer, self.azp_adj_name),
# return a covariant type in the subclass
@abstractmethod
def _get_layer_params(self, layer) -> _ParamsT:
raise NotImplementedError
class FP8ScaledMMLinearKernel(
    ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
):
    """Abstract base for FP8 scaled-matmul linear kernels.

    Handles activation quantization and input reshaping in `apply_weights`;
    subclasses supply the actual matmul through `apply_scaled_mm`.
    """

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        # The input quantizer follows the activation scale descriptor; its
        # token padding comes from the subclass hook get_output_padding().
        activation_scale = c.activation_quant_key.scale
        self.quant_fp8 = QuantFP8(
            static=activation_scale.static,
            group_shape=activation_scale.group_shape,
            num_token_padding=self.get_output_padding(),
        )
        self.fp8_dtype = current_platform.fp8_dtype()
        # Base-class init runs last, after the quantizer is in place.
        super().__init__(c, layer_param_names)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No post-load weight processing is done at this level."""

    def _get_layer_params(self, layer) -> _FP8ParamsT:
        """Fetch (weight, weight_scale, input_scale, input_scale_ub) from layer.

        The weight and its scale must exist on the layer; the two input-scale
        attributes fall back to None when absent.
        """
        (
            weight_name,
            weight_scale_name,
            input_scale_name,
            input_scale_ub_name,
        ) = self.layer_param_names
        weight = getattr(layer, weight_name)
        weight_scale = getattr(layer, weight_scale_name)
        input_scale = getattr(layer, input_scale_name, None)
        input_scale_ub = getattr(layer, input_scale_ub_name, None)
        return (weight, weight_scale, input_scale, input_scale_ub)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize `x` if needed and run the kernel's scaled matmul.

        Args:
            layer: Module holding the tensors named by layer_param_names.
            x: Input activations; flattened to 2D over the last dimension.
            bias: Optional bias forwarded to apply_scaled_mm.

        Returns:
            Whatever apply_scaled_mm returns for the quantized input.
        """
        weight, weight_scale, input_scale, input_scale_ub = self._get_layer_params(
            layer
        )

        # FP8 methods operate on a 2D matrix; the original leading dims and
        # the weight's second dim make up the requested output shape.
        flat_input = x.view(-1, x.shape[-1])
        output_shape = [*x.shape[:-1], weight.shape[1]]

        out_dtype = self.config.out_dtype
        if out_dtype is None:
            out_dtype = x.dtype

        # Input that is already in the FP8 dtype is passed through untouched.
        # Otherwise quant_fp8 produces the quantized input and the scale to
        # use: the stored input_scale is fed in and may be None.
        # TODO(luka) remove this path if not used anymore
        if x.dtype == self.fp8_dtype:
            quantized_input = flat_input
        else:
            quantized_input, input_scale = self.quant_fp8(
                flat_input,
                input_scale,
                input_scale_ub,
            )

        return self.apply_scaled_mm(
            A=quantized_input,
            B=weight,
            out_dtype=out_dtype,
            As=input_scale,
            Bs=weight_scale,
            bias=bias,
            output_shape=output_shape,
        )

    @abstractmethod
    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Backend-specific scaled matmul of A and B with scales As and Bs."""
        raise NotImplementedError

    def get_output_padding(self) -> int | None:
        """Return the num_token_padding given to QuantFP8; None by default."""
        return None
class Int8ScaledMMLinearKernel(
    ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
):
    """Abstract base for INT8 scaled-matmul linear kernels."""

    def _get_layer_params(self, layer) -> _Int8ParamsT:
        """Fetch (weight, weight_scale, input_scale, input_zp, azp_adj).

        The weight and its scale must exist on the layer; the remaining three
        attributes fall back to None when absent.
        """
        (
            weight_name,
            weight_scale_name,
            input_scale_name,
            input_zp_name,
            azp_adj_name,
        ) = self.layer_param_names
        weight = getattr(layer, weight_name)
        weight_scale = getattr(layer, weight_scale_name)
        input_scale = getattr(layer, input_scale_name, None)
        input_zp = getattr(layer, input_zp_name, None)
        azp_adj = getattr(layer, azp_adj_name, None)
        return (weight, weight_scale, input_scale, input_zp, azp_adj)

View File

@@ -2,76 +2,229 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TypeVar
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__)
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
PlatformEnum.CUDA: [
CutlassInt8ScaledMMLinearKernel,
TritonInt8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel],
}
# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [
FlashInferFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
ROCmFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.CPU: [
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
def is_supported_and_can_implement_kernel(
    kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
) -> tuple[bool, str]:
    """Check whether `kernel` is enabled, supported and able to run `config`.

    Args:
        kernel: Kernel class to check.
        config: Layer configuration passed to `kernel.can_implement`.
        compute_capability: Capability passed to `kernel.is_supported`. When
            None, it is derived from `current_platform` as major * 10 + minor
            if the platform reports one.

    Returns:
        `(True, "")` when the kernel can be used, otherwise `(False, reason)`
        where the reason starts with the kernel's class name.
    """
    # TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
    # Entries are stripped so that "A, B" disables both A and B.
    disabled_kernels = {
        name.strip()
        for name in os.environ.get("VLLM_DISABLED_KERNELS", "").split(",")
    }
    if kernel.__name__ in disabled_kernels:
        return False, f"{kernel.__name__} is disabled by environment variable"

    if compute_capability is None:
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    is_supported, failure_reason = kernel.is_supported(compute_capability)
    if not is_supported:
        return False, f"{kernel.__name__} {failure_reason}."

    can_implement, failure_reason = kernel.can_implement(config)
    if not can_implement:
        return False, f"{kernel.__name__} {failure_reason}."

    return True, ""
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
) -> type[ScaledMMLinearKernel]:
config: _KernelConfigT,
possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
compute_capability: int | None = None,
force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
Choose a _KernelT that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (ScaledMMLinearLayerConfig): Description of the linear layer
config (_KernelConfigT): Description of the linear layer
to be implemented.
possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
dictionary of platforms and their list of possible kernels.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override
the possible_kernels if it can be implemented. If None, it will only try the
possible kernels.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
type[ScaledMMLinearKernel]: Chosen kernel.
_KernelT: Chosen kernel.
"""
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
continue
failure_reason_list = []
# If the current platform uses compute_capability,
# make sure the kernel supports the compute capability.
is_supported, reason = kernel.is_supported(compute_capability)
if not is_supported:
failure_reasons.append(f"{kernel.__name__}: {reason}")
continue
if force_kernel is not None:
can_implement, failure_reason = is_supported_and_can_implement_kernel(
force_kernel, config, compute_capability
)
if can_implement:
return force_kernel
can_implement, reason = kernel.can_implement(config)
if not can_implement:
failure_reasons.append(f"{kernel.__name__}: {reason}")
continue
logger.info_once(
"Tried to force %s, but the kernel couldn't be implemented",
force_kernel.__name__,
scope="global",
)
return kernel
for kernel in possible_kernels[current_platform._enum]:
is_supported_and_can_implement, failure_reason = (
is_supported_and_can_implement_kernel(kernel, config, compute_capability)
)
if is_supported_and_can_implement:
return kernel
failure_reason_list.append(failure_reason)
raise ValueError(
"Failed to find a kernel that can implement the "
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
)
def init_fp8_linear_kernel(
    activation_quant_key: QuantKey,
    weight_quant_key: QuantKey,
    out_dtype: torch.dtype,
    force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
    module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
    """Select and instantiate an FP8 scaled-mm linear kernel.

    Args:
        activation_quant_key: Quantization key stored in the kernel config
            for the activations.
        weight_quant_key: Quantization key stored in the kernel config for
            the weights.
        out_dtype: Output dtype stored in the kernel config.
        force_kernel: Optional kernel class forwarded to
            `choose_scaled_mm_linear_kernel` as `force_kernel`.
        module_name: Optional label used only in the selection log line;
            nothing is logged when it is None or empty.

    Returns:
        An instance of the chosen kernel class, built from the config and
        the FP8 layer parameter names.
    """
    config = FP8ScaledMMLinearLayerConfig(
        weight_quant_key=weight_quant_key,
        activation_quant_key=activation_quant_key,
        out_dtype=out_dtype,
    )

    selected_cls = choose_scaled_mm_linear_kernel(
        config,
        _POSSIBLE_FP8_KERNELS,
        force_kernel=force_kernel,
    )

    if module_name:
        logger.info_once(
            "Selected %s for %s",
            selected_cls.__name__,
            module_name,
            scope="global",
        )

    # Attribute names the kernel will look up on the layer.
    param_names = ["weight", "weight_scale", "input_scale", "input_scale_ub"]
    return selected_cls(config, layer_param_names=param_names)
def init_int8_linear_kernel(
    is_channelwise: bool,
    is_static_input_scheme: bool,
    input_symmetric: bool,
    module_name: str,
) -> Int8ScaledMMLinearKernel:
    """Select and instantiate an INT8 scaled-mm linear kernel.

    Args:
        is_channelwise: Stored in the kernel config as `is_channelwise`.
        is_static_input_scheme: Stored in the kernel config as
            `is_static_input_scheme`.
        input_symmetric: Stored in the kernel config as `input_symmetric`.
        module_name: Label used in the selection log line.

    Returns:
        An instance of the chosen kernel class, built from the config and
        the INT8 layer parameter names.
    """
    int8_config = Int8ScaledMMLinearLayerConfig(
        is_channelwise=is_channelwise,
        is_static_input_scheme=is_static_input_scheme,
        input_symmetric=input_symmetric,
    )

    selected_cls = choose_scaled_mm_linear_kernel(int8_config, _POSSIBLE_INT8_KERNELS)

    logger.info_once(
        "Selected %s for %s",
        selected_cls.__name__,
        module_name,
        scope="global",
    )

    # Attribute names the kernel will look up on the layer.
    param_names = [
        "weight",
        "weight_scale",
        "input_scale",
        "input_zero_point",
        "azp_adj",
    ]
    return selected_cls(int8_config, layer_param_names=param_names)

View File

@@ -8,60 +8,41 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.",
)
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
return False, "Requires ROCm."
if compute_capability is not None and compute_capability < 90:
return False, f"requires capability 90, got {compute_capability}"
return False, "requires compute capability 90 and above."
try:
import aiter # noqa: F401 # deliberately attempt to import aiter
except Exception:
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.",
)
return False, "requires `aiter` to be installed."
if not rocm_aiter_ops.is_linear_enabled():
return (
False,
"AiterScaledMMLinearKernel is disabled. "
+ "Enable by setting `VLLM_ROCM_USE_AITER=1` "
"requires setting `VLLM_ROCM_USE_AITER=1` "
+ "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return (
False,
"AiterScaledMMLinearKernel only supports symmetric " + "quantization.",
)
return False, "supports symmetric quantization only."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
def apply_weights(
self,
layer: torch.nn.Module,
@@ -69,28 +50,28 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
`AiterScaledMMLinearKernel` implements a fused version of
`AiterInt8ScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` also does not support
AITER block scaled GEMM and mixed-precision GEMM.
"""
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
assert symmetric, (
"AiterScaledMMLinearKernel only supports symmetric quantization."
"AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
)
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
assert x_zp is None, (
"AiterScaledMMLinearKernel only supports symmetric quantization."
"AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
)
out_dtype = x.dtype
@@ -117,12 +98,12 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
), (
"Currently only support per-tensor-per-tensor GEMM "
+ " and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
" w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` "
+ "does not support AITER block scaled GEMM."
)
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@@ -14,24 +14,28 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "Requires CPU."
return False, "requires CPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name)
w_q_name, _, _, _, _ = self.layer_param_names
weight = getattr(layer, w_q_name)
dtype = weight.dtype
N, K = weight.size()
if (
@@ -49,10 +53,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Transpose to [K, N] for convenience
weight = getattr(layer, self.w_q_name)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
self.w_q_name,
w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
@@ -61,28 +66,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
input_scale = getattr(layer, i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
@@ -92,20 +96,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
azp = (
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the
# term. Such as:
@@ -113,38 +113,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
weight = getattr(layer, self.w_q_name)
weight_scale = getattr(layer, self.w_s_name)
weight = getattr(layer, w_q_name)
weight_scale = getattr(layer, w_s_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
azp_adj = azp_adj * weight_scale.squeeze()
setattr(
layer,
self.azp_adj_name,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
weight = getattr(layer, self.w_q_name)
weight = getattr(layer, w_q_name)
self.dnnl_handler = ops.create_onednn_scaled_mm(
weight,
getattr(layer, self.w_s_name),
getattr(layer, w_s_name),
torch.get_default_dtype(),
getattr(layer, self.i_s_name) is None,
getattr(layer, i_s_name) is None,
not self.config.input_symmetric,
32,
)
# weight is prepacked and maintained by the dnnl_handler,
# release the original weight
setattr(layer, self.w_q_name, None)
setattr(layer, w_q_name, None)
del weight
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, _, _, _ = self.layer_param_names
# WEIGHT
weight = getattr(layer, self.w_q_name)
weight = getattr(layer, w_q_name)
packed_weight = torch.ops._C.convert_weight_packed(weight)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
)
if layer.bias is not None:
@@ -156,19 +155,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
weight_scale = getattr(layer, self.w_s_name)
weight_scale = getattr(layer, w_s_name)
if not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
@@ -187,7 +182,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
@@ -209,7 +204,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
w_q, w_s, _, _, _ = self._get_layer_params(layer)
return torch.ops._C.int8_scaled_mm_with_quant(
x,
w_q,

View File

@@ -11,35 +11,36 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "Requires CUDA."
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 75:
return False, f"requires capability 75, got {compute_capability}"
return False, "requires CUDA."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
config = self.config
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
self.w_q_name,
w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
@@ -48,28 +49,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
if config.is_static_input_scheme:
input_scale = getattr(layer, i_s_name)
if self.config.input_symmetric:
if config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
setattr(layer, i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
@@ -79,38 +80,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not self.config.input_symmetric:
weight = getattr(layer, self.w_q_name)
if not config.input_symmetric:
weight = getattr(layer, w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.config.is_static_input_scheme:
if config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
azp_adj = getattr(layer, i_zp_name) * azp_adj
setattr(
layer,
self.azp_adj_name,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
@@ -118,7 +113,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
@@ -145,3 +140,34 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return ops.cutlass_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-mm kernel that dispatches to `ops.cutlass_scaled_mm`."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return `(True, None)` on CUDA, otherwise `(False, reason)`.

        `compute_capability` is accepted but not inspected here.
        """
        if current_platform.is_cuda():
            return True, None
        return False, "requires CUDA."

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config; this kernel places no restriction on `c`."""
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the fused GEMM + dequant and view the result as `output_shape`."""
        # Fused GEMM_DQ
        gemm_out = ops.cutlass_scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
        return gemm_out.view(*output_shape)

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-mm kernel that dispatches to `flashinfer_scaled_fp8_mm`."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check platform, FlashInfer availability and compute capability.

        The capability check is skipped when `compute_capability` is None.
        """
        if not current_platform.is_cuda():
            return False, "requires CUDA."

        if not has_flashinfer():
            return False, "requires FlashInfer to be installed."

        capability_too_low = compute_capability is not None and compute_capability < 100
        if capability_too_low:
            return False, "requires compute capability 100 and above."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept only configs whose activation and weight scales are both
        per tensor."""
        act_per_tensor = c.activation_quant_key.scale.group_shape.is_per_tensor()
        weight_per_tensor = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if act_per_tensor and weight_per_tensor:
            return True, None
        return False, "requires per tensor activation and weight scales."

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Return the FlashInfer scaled FP8 matmul result.

        `output_shape` is accepted for interface compatibility but is not
        used by this kernel.
        """
        result = flashinfer_scaled_fp8_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
        return result

View File

@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """
    Base class for FP8 linear kernels using Torch.
    Each subclass represents a kernel variant for
    specific device capabilities and torch versions.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check the platform and, when given, the compute capability.

        The capability check is skipped when `compute_capability` is None.
        """
        on_supported_platform = (
            current_platform.is_cuda_alike() or current_platform.is_cpu()
        )
        if not on_supported_platform:
            return False, "requires ROCm, CUDA or CPU."

        if compute_capability is not None and compute_capability < 89:
            return False, "requires compute capability 89 and above."

        return True, None

    def get_output_padding(self) -> int | None:
        """Return 17 when the compilation mode is below `VLLM_COMPILE`,
        otherwise None."""
        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        #
        # The perf gain is still relevant as of 16/1/2026
        # torch version == 2.9.0. More details in the link below:
        # https://github.com/vllm-project/vllm/issues/32269
        compilation_config = get_current_vllm_config().compilation_config
        if compilation_config.mode < CompilationMode.VLLM_COMPILE:
            return 17
        return None
class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    """Torch FP8 kernel that passes both scales straight to
    `torch._scaled_mm`; restricted to per-tensor scales."""

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept only configs whose activation and weight scales are both
        per tensor."""
        act_per_tensor = c.activation_quant_key.scale.group_shape.is_per_tensor()
        weight_per_tensor = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if act_per_tensor and weight_per_tensor:
            return True, None
        return False, "requires per tensor activation and weight scales."

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run `torch._scaled_mm`, keep the first `output_shape[0]` rows and
        view the result as `output_shape`."""
        mm_out = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if type(mm_out) is tuple and len(mm_out) == 2:
            mm_out = mm_out[0]

        num_rows = output_shape[0]
        trimmed = torch.narrow(mm_out, 0, 0, num_rows)
        return trimmed.view(*output_shape)
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    """Torch FP8 kernel that calls `torch._scaled_mm` with row-wise scales.

    Only reported as supported on ROCm MI3xx with torch >= 2.7, and only for
    configs where neither the activation nor the weight scale is per tensor.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check platform, device family, compute capability and torch version.

        The capability check is skipped when `compute_capability` is None.
        """
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        # Imported lazily, only after the platform is known to be ROCm.
        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if compute_capability is not None and compute_capability < 94:
            return False, "requires compute capability 94 and above."

        if version.parse(torch.__version__) < version.parse("2.7"):
            return False, "requires pytorch version >=2.7."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Reject float16 output and any config with a per-tensor scale.

        The config is rejected when EITHER the activation or the weight scale
        is per tensor, so the returned reason says "or" to match the check.
        """
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if c.out_dtype == torch.float16:
            # hipblaslt rowwise _scaled_mm only supports BFloat16
            return False, "supports BFloat16 output data type only."

        if per_tensor_activation_scales or per_tensor_weight_scales:
            return False, "cannot be used with per tensor activation or weight scales."

        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the row-wise scaled GEMM, keep the first `output_shape[0]` rows
        and view the result as `output_shape`."""
        # Note:
        # For now it has only been validated on ROCm platform.
        # fp8 rowwise scaling in torch._scaled_mm is introduced in
        # https://github.com/pytorch/pytorch/pull/144432 using
        # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
        #
        # For CUDA platform please validate if the torch._scaled_mm supports
        # rowwise scaled GEMM before using it

        # Fused GEMM_DQ Rowwise GEMM
        output = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            # The weight scale is passed transposed.
            scale_b=Bs.t(),
            bias=bias,
        )
        return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    """Torch FP8 kernel that runs an unscaled `torch._scaled_mm` and applies
    the activation and weight scales afterwards."""

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Reject configs where both activation and weight scales are per
        tensor; accept everything else."""
        act_per_tensor = c.activation_quant_key.scale.group_shape.is_per_tensor()
        weight_per_tensor = c.weight_quant_key.scale.group_shape.is_per_tensor()

        both_per_tensor = act_per_tensor and weight_per_tensor
        if both_per_tensor:
            return False, "cannot be used with per tensor activation and weight scales."
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Compute the GEMM in fp32 with unit scales, then dequantize, add the
        bias if given, cast to `out_dtype` and view as `output_shape`."""
        # Use unfused DQ due to limitations with scaled_mm

        # Symmetric quantized GEMM by definition computes the following:
        #   C = (s_x * X) (s_w * W) + bias
        # This is equivalent to dequantizing the weights and activations
        # before applying a GEMM.
        #
        # In order to compute quantized operands, a quantized kernel
        # will rewrite the above like so:
        #   C = s_w * s_x * (X * W) + bias
        #
        # For the scaled_mm fallback case, we break this down, since it
        # does not support s_w being a vector.

        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as scales
        unit_scale = torch.ones(1, dtype=torch.float32, device=A.device)

        # GEMM
        # This computes C = (X * W).
        # Output in fp32 to allow subsequent ops to happen in-place
        raw = torch._scaled_mm(
            A,
            B,
            scale_a=unit_scale,
            scale_b=unit_scale,
            out_dtype=torch.float32,
        )
        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if type(raw) is tuple and len(raw) == 2:
            raw = raw[0]

        # Unpad (undo num_token_padding)
        num_rows = output_shape[0]
        raw = torch.narrow(raw, 0, 0, num_rows)
        act_scale = torch.narrow(As, 0, 0, num_rows)

        # DQ
        # C = sw * sx * (X * W) + bias
        result = raw * act_scale * Bs.t()
        if bias is not None:
            result = result + bias

        return result.to(out_dtype).view(*output_shape)

View File

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Per-tensor scaled matmul with a `wvSplitKQ` path for single-row input.

    `ops.wvSplitKQ` is used when `A` has exactly one row, `B.shape[1]` is a
    multiple of 16, and `bias` is either None or already of `out_dtype`.
    Every other case falls back to `torch._scaled_mm`.
    """
    # NOTE(review): `bias` is annotated as a Tensor but is compared with None
    # below — confirm whether the annotation is intentionally non-optional.
    use_split_k = (
        A.shape[0] == 1
        and B.shape[1] % 16 == 0
        and (bias is None or bias.dtype == out_dtype)
    )

    if not use_split_k:
        # Fallback
        return torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )

    return ops.wvSplitKQ(
        B.t(),
        A,
        out_dtype,
        As,
        Bs,
        get_cu_count(),
        bias,
    )
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Return an uninitialized tensor with the matmul's output shape.

    The result keeps every dimension of `A` except the last, which is
    replaced by `B.shape[1]`; it has dtype `out_dtype` and is allocated via
    `A.new_empty`. `As`, `Bs` and `bias` are not used.
    """
    leading_dims = tuple(A.shape[:-1])
    out_features = B.shape[1]
    return A.new_empty(leading_dims + (out_features,), dtype=out_dtype)
# Register the implementation as a custom op, with the fake function supplied
# as `fake_impl`. Registration only happens when the current platform is ROCm.
if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
        op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
        fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
    )
class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-mm kernel that dispatches to the
    `rocm_per_tensor_float_w8a8_scaled_mm_impl` custom op."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check platform, device family and the skinny-GEMM env flag.

        `compute_capability` is accepted but not inspected here.
        """
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        # Imported lazily, only after the platform is known to be ROCm.
        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if envs.VLLM_ROCM_USE_SKINNY_GEMM:
            return True, None
        return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept only configs whose activation and weight scales are both
        per tensor."""
        act_per_tensor = c.activation_quant_key.scale.group_shape.is_per_tensor()
        weight_per_tensor = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if act_per_tensor and weight_per_tensor:
            return True, None
        return False, "requires per tensor activation and weight scales."

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Call the registered custom op, keep the first `A.shape[0]` rows and
        view the result as `output_shape`."""
        mm_out = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
            A, B, out_dtype, As, Bs, bias
        )
        num_rows = A.shape[0]
        trimmed = torch.narrow(mm_out, 0, 0, num_rows)
        return trimmed.view(*output_shape)

View File

@@ -14,30 +14,35 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig,
)
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if current_platform.is_cuda_alike():
return True, None
return False, "Requires ROCm or CUDA."
return False, "requires ROCm or CUDA."
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return False, "Only symmetric input is supported."
return False, "supports symmetric input only."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name)
w_q, _, i_s, _, _ = self._get_layer_params(layer)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
w_q_name,
torch.nn.Parameter(w_q.t().data, requires_grad=False),
)
# WEIGHT SCALE
@@ -45,29 +50,29 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
assert i_s is not None
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
i_s_name,
torch.nn.Parameter(i_s.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
setattr(layer, i_zp_name, None)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, i_s_name, None)
setattr(layer, i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
setattr(layer, azp_adj_name, None)
def apply_weights(
self,
@@ -75,7 +80,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=True

View File

@@ -49,6 +49,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@@ -78,10 +81,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
cutlass_fp4_supported,
is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
requantize_with_max_scale,
)
@@ -438,8 +443,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8StaticTensorSym,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights(
@@ -507,13 +515,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
@@ -527,8 +529,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights(
@@ -585,13 +590,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)
class ModelOptFp8PbWoLinearMethod(LinearMethodBase):

View File

@@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod,
Fp8LinearMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
kFp8DynamicTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
)
super().__init__(quant_config=quant_config)
# Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8DynamicTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -130,11 +134,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=None,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@@ -7,10 +7,18 @@ from typing import Any, cast
import torch
from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
)
@@ -23,6 +31,8 @@ from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"]
logger = init_logger(__name__)
class QuarkW8A8Fp8(QuarkScheme):
def __init__(
@@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme"))
per_token = (
per_token_activation = (
not self.is_static_input_scheme and self.input_qscheme == "per_channel"
)
self.act_quant_group_shape = (
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
per_token_weight = self.weight_qscheme == "per_channel"
self.activation_quant_key = (
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_quant_group_shape,
self.weight_quant_key = (
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
)
self.out_dtype = torch.get_default_dtype()
@@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = Parameter(input_scale, requires_grad=False)
else:
weight_scale = layer.weight_scale.data
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
@@ -106,8 +117,6 @@ class QuarkW8A8Fp8(QuarkScheme):
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else:
layer.input_scale = None
def create_weights(
self,
@@ -163,17 +172,17 @@ class QuarkW8A8Fp8(QuarkScheme):
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@@ -7,8 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel,
init_int8_linear_kernel,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (
@@ -22,8 +21,6 @@ logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(
self,
qscheme: str,
@@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme):
):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
self.kernel = init_int8_linear_kernel(
is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True),
module_name=self.__class__.__name__,
)
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
@@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
input_zero_point = None
input_scale = None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
layer.register_parameter("input_scale", input_scale)
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("input_zero_point", input_zero_point)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.

View File

@@ -123,6 +123,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
# Scale descriptors and the symmetric FP8 quantization keys built from them.
# NOTE(review): the second positional ScaleDesc argument looks like the
# "static" flag (True for the kStatic* descriptor, False for the kDynamic*
# ones) -- confirm against the ScaleDesc definition.

# float32 scale, one per tensor, flag=False.
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
# float32 scale, PER_TOKEN group shape, flag=True.
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
# float32 scale, PER_TOKEN group shape, flag=False.
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)

View File

@@ -1,34 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from packaging import version
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op
# Input scaling factors are no longer optional in torch._scaled_mm starting
# from PyTorch 2.5, so a dummy tensor is kept around to pass as the scale.
# It stays None until maybe_create_device_identity() allocates it.
TORCH_DEVICE_IDENTITY = None
# Whether the current platform supports the torch._scaled_mm rowwise feature.
# The flag is evaluated once at import time because the underlying checks
# are time-consuming. All three conditions must hold: running on ROCm,
# torch >= 2.7, and current_platform.has_device_capability(94).
USE_ROWWISE_TORCH_SCALED_MM = (
current_platform.is_rocm()
and version.parse(torch.__version__) >= version.parse("2.7")
and current_platform.has_device_capability(94)
)
def sparse_cutlass_supported() -> bool:
@@ -140,361 +117,6 @@ def requantize_with_max_scale(
return max_w_scale, weight
def maybe_create_device_identity():
    """Lazily allocate the module-level identity scale for torch._scaled_mm.

    On the first call this binds ``TORCH_DEVICE_IDENTITY`` to a one-element
    float32 tensor of ones; later calls leave the existing tensor untouched.
    """
    # Allocate dummy ones tensor for torch._scaled_mm
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def cutlass_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Run the scaled matmul through ``ops.cutlass_scaled_mm``.

    Arguments are keyword-only; extra keywords are accepted and ignored.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Dtype requested from the kernel for the result.
        scale_a: Scale for ``qinput``.
        scale_b: Scale for ``weight``.
        bias: Optional bias forwarded to the kernel.
        output_shape: Shape the kernel output is viewed as before returning.

    Returns:
        The kernel output viewed as ``output_shape``.
    """
    # Fused GEMM_DQ
    output = ops.cutlass_scaled_mm(
        qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
    )
    return output.view(*output_shape)
def flashinfer_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Run the scaled matmul through ``flashinfer_scaled_fp8_mm``.

    Arguments are keyword-only; extra keywords are accepted and ignored.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Dtype requested from the kernel for the result.
        scale_a: Scale for ``qinput``.
        scale_b: Scale for ``weight``.
        bias: Optional bias forwarded to the kernel.
        output_shape: Accepted for signature parity with the other backends;
            not used here.

    Returns:
        Whatever ``flashinfer_scaled_fp8_mm`` returns, unmodified.
    """
    # NOTE(review): unlike cutlass_w8a8_scaled_mm, the result is not viewed as
    # `output_shape` -- confirm the flashinfer op already yields that shape.
    return flashinfer_scaled_fp8_mm(
        qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
    )
def rocm_per_tensor_w8a8_scaled_mm_impl(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Scaled matmul used as the body of the ROCm per-tensor custom op.

    Chooses between ``ops.wvSplitKQ`` and ``torch._scaled_mm``. The first is
    taken only when every condition below holds:

    * ``envs.VLLM_ROCM_USE_SKINNY_GEMM`` is truthy,
    * ``on_mi3xx()`` is true,
    * ``qinput`` has exactly one row and its second dimension is a multiple
      of 16,
    * ``bias`` is ``None`` or already has dtype ``out_dtype``.

    Returns:
        The output of whichever kernel ran, without reshaping.
    """
    # NOTE(review): `bias` is annotated as a plain Tensor although None is
    # handled below. The annotation is left as-is because this function is
    # registered as a custom op and its signature may feed the op schema.
    # Imported lazily inside the function, as in the original.
    from vllm.platforms.rocm import on_mi3xx

    if (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_mi3xx()
        and qinput.shape[0] == 1
        and qinput.shape[1] % 16 == 0
        and ((bias is None) or (bias.dtype == out_dtype))
    ):
        # Single-row input: the skinny GEMM takes the transposed weight and
        # the compute-unit count.
        output = ops.wvSplitKQ(
            weight.t(),
            qinput,
            out_dtype,
            scale_a,
            scale_b,
            get_cu_count(),
            bias,
        )
    else:
        output = torch._scaled_mm(
            qinput,
            weight,
            out_dtype=out_dtype,
            scale_a=scale_a,
            scale_b=scale_b,
            bias=bias,
        )
    return output
def rocm_per_tensor_w8a8_scaled_mm_fake(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Shape-only stand-in for ``rocm_per_tensor_w8a8_scaled_mm_impl``.

    No matmul is performed: the result is an uninitialised tensor created
    from ``qinput`` whose shape is ``qinput``'s leading dimensions followed by
    ``weight.shape[1]``, with dtype ``out_dtype``. The scales and bias are
    ignored.
    """
    return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype)
def rocm_per_tensor_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Call the registered ROCm per-tensor custom op and reshape its result.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Dtype requested for the result.
        scale_a: Scale for ``qinput``.
        scale_b: Scale for ``weight``.
        bias: Optional bias forwarded to the op.
        output_shape: Shape the trimmed result is viewed as.

    Returns:
        The first ``qinput.shape[0]`` rows of the op output, viewed as
        ``output_shape``.
    """
    output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
        qinput, weight, out_dtype, scale_a, scale_b, bias
    )
    # Keep only the first qinput.shape[0] rows before reshaping.
    return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
# Register the ROCm per-tensor scaled-mm implementation under the name
# "rocm_per_tensor_w8a8_scaled_mm_impl" (it is invoked elsewhere in this file
# as torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl), with the shape-only
# function supplied as its fake implementation.
direct_register_custom_op(
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
)
def torch_per_tensor_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Run the scaled matmul with ``torch._scaled_mm`` and reshape the result.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Dtype requested for the result.
        scale_a: Scale for ``qinput``.
        scale_b: Scale for ``weight``.
        bias: Optional bias forwarded to ``torch._scaled_mm``.
        output_shape: Shape the trimmed result is viewed as.

    Returns:
        The first ``qinput.shape[0]`` rows of the matmul output, viewed as
        ``output_shape``.
    """
    output = torch._scaled_mm(
        qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]

    return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Rowwise-scaled matmul through ``torch._scaled_mm``.

    ``scale_b`` is transposed before being handed to ``torch._scaled_mm``.
    Extra keyword arguments are accepted and ignored.

    Returns:
        The first ``qinput.shape[0]`` rows of the matmul output, viewed as
        ``output_shape``.
    """
    # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    #  when using it.
    #  For now it has only been validated on ROCm platform.
    #  fp8 rowwise scaling in torch._scaled_mm is introduced in
    #  https://github.com/pytorch/pytorch/pull/144432 using
    #  hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
    #
    #  For CUDA platform please validate if the torch._scaled_mm supports
    #  rowwise scaled GEMM before using it

    # Fused GEMM_DQ Rowwise GEMM
    output = torch._scaled_mm(
        qinput,
        weight,
        out_dtype=out_dtype,
        scale_a=scale_a,
        scale_b=scale_b.t(),
        bias=bias,
    )

    output = torch.narrow(output, 0, 0, qinput.shape[0])
    output = output.view(*output_shape)
    return output
def torch_channelwise_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Unfused scaled matmul: plain GEMM first, then dequantize by hand.

    ``torch._scaled_mm`` is called with the module-level
    ``TORCH_DEVICE_IDENTITY`` as both scales and float32 output; the real
    scales and the optional bias are applied afterwards. Extra keyword
    arguments are accepted and ignored.

    NOTE(review): ``TORCH_DEVICE_IDENTITY`` is None until
    ``maybe_create_device_identity()`` has run -- presumably callers must
    invoke it first; confirm.

    Returns:
        The dequantized result cast to ``out_dtype`` and viewed as
        ``output_shape``.
    """
    # Use unfused DQ due to limitations with scaled_mm

    # Symmetric quantized GEMM by definition computes the following:
    #   C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    #   C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.

    # GEMM
    # This computes C = (X * W).
    # Output in fp32 to allow subsequent ops to happen in-place
    output = torch._scaled_mm(
        qinput,
        weight,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]
    # Unpad (undo num_token_padding)
    output = torch.narrow(output, 0, 0, qinput.shape[0])
    x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])

    # DQ
    # C = sw * sx * (X * W) + bias
    output = output * x_scale * scale_b.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
    preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool
) -> Callable[..., torch.Tensor]:
    """Pick the scaled-matmul function for a backend and scale granularity.

    Args:
        preferred_backend: Backend name; "rocm", "flashinfer" and "cutlass"
            are recognised, anything else falls through to the torch paths.
        per_tensor_weights: True when the weight scale is per-tensor.
        per_tensor_activations: True when the activation scale is per-tensor.

    Returns:
        One of the ``*_w8a8_scaled_mm`` functions defined in this module.
    """
    if per_tensor_weights and per_tensor_activations:
        if preferred_backend == "rocm":
            return rocm_per_tensor_w8a8_scaled_mm
        if preferred_backend == "flashinfer":
            return flashinfer_w8a8_scaled_mm
        if preferred_backend == "cutlass":
            return cutlass_w8a8_scaled_mm
        return torch_per_tensor_w8a8_scaled_mm

    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
    # (the flashinfer backend also lands on the cutlass function here).
    if preferred_backend in ("cutlass", "flashinfer"):
        return cutlass_w8a8_scaled_mm

    # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
    if (
        not per_tensor_weights
        and not per_tensor_activations
        and USE_ROWWISE_TORCH_SCALED_MM
    ):
        return torch_per_token_w8a8_scaled_mm
    # Normally, torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
#  https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
    """
    Executes an FP8 linear layer: quantizes the input if needed, then runs
    the scaled matmul chosen for the preferred backend ("rocm", "flashinfer",
    "cutlass" or "torch").

    It needs to be a class instead of a method so that config can be read
    in the __init__ method, as reading config is not allowed inside forward.
    """

    def __init__(
        self,
        act_quant_static: bool,
        act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
        pad_output: bool | None = None,
    ):
        """Select the backend and build the activation quantizer.

        Args:
            act_quant_static: Passed to ``QuantFP8`` as ``static``.
            act_quant_group_shape: Passed to ``QuantFP8`` as ``group_shape``.
            pad_output: Whether to pad the token dimension. When None, padding
                is enabled only for the "torch" backend and only when the
                compilation mode is below ``CompilationMode.VLLM_COMPILE``.
        """
        if current_platform.is_rocm():
            self.preferred_backend = "rocm"
        elif current_platform.is_cuda() and cutlass_fp8_supported():
            if has_flashinfer() and current_platform.has_device_capability(100):
                self.preferred_backend = "flashinfer"
            else:
                self.preferred_backend = "cutlass"
        else:
            self.preferred_backend = "torch"

        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        if pad_output is None:
            config = get_current_vllm_config().compilation_config
            pad_output = (
                config.mode < CompilationMode.VLLM_COMPILE
                and self.preferred_backend == "torch"
            )
        self.output_padding = 17 if pad_output else None

        self.act_quant_static = act_quant_static
        self.act_quant_group_shape = act_quant_group_shape
        self.quant_fp8 = QuantFP8(
            static=act_quant_static,
            group_shape=act_quant_group_shape,
            num_token_padding=self.output_padding,
        )

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = None,
        input_scale: torch.Tensor | None = None,
        input_scale_ub: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``input`` if necessary and run the scaled matmul.

        Args:
            input: Activations; flattened to 2D over the last dimension.
            weight: Quantized weight; ``weight.shape[1]`` becomes the last
                output dimension.
            weight_scale: Scale for ``weight``.
            out_dtype: Result dtype; defaults to ``input.dtype``.
            input_scale: Optional activation scale handed to the quantizer,
                or used directly when ``input`` is already FP8.
            input_scale_ub: Optional value forwarded to the quantizer.
            bias: Optional bias forwarded to the matmul function.

        Returns:
            The output of the dispatched matmul function, which is given an
            output shape of ``input``'s leading dims plus ``weight.shape[1]``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.

        # View input as 2D matrix for fp8 methods
        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[1]]

        if out_dtype is None:
            out_dtype = input.dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        if input.dtype != current_platform.fp8_dtype():
            qinput, x_scale = self.quant_fp8(
                input_2d,
                input_scale,
                input_scale_ub,
            )
        else:
            qinput, x_scale = input_2d, input_scale

        # Must have dim() conditions
        # In the per-token quant scenario, when the number of tokens is 1,
        # the scale will only have 1 element.
        # Without checking dim(),
        # we cannot distinguish between per-tensor and per-token quant.
        # Example:
        # When the number of tokens is 1, the per-token scale is [[1]]
        # while the per-tensor scale is [1] or ().
        per_tensor_weights = weight_scale.numel() == 1
        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

        # TODO(luka) do this dispatch during init (after ScaledMM refactor)
        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
            self.preferred_backend, per_tensor_weights, per_tensor_activations
        )

        return w8a8_scaled_mm_func(
            qinput=qinput,
            weight=weight,
            out_dtype=out_dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
            output_shape=output_shape,
        )
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,