[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com> Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
@@ -9,11 +9,12 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.benchmarks.lib.utils import default_vllm_config
|
from vllm.benchmarks.lib.utils import default_vllm_config
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
W8A8BlockFp8LinearOp,
|
init_fp8_linear_kernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
|
create_fp8_quant_key,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||||
@@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
|||||||
weight_group_shape = GroupShape(block_n, block_k)
|
weight_group_shape = GroupShape(block_n, block_k)
|
||||||
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||||
|
|
||||||
linear_op = W8A8BlockFp8LinearOp(
|
linear_op = init_fp8_linear_kernel(
|
||||||
weight_group_shape=weight_group_shape,
|
weight_quant_key=create_fp8_quant_key(
|
||||||
act_quant_group_shape=act_quant_group_shape,
|
static=True, group_shape=weight_group_shape
|
||||||
cutlass_block_fp8_supported=use_cutlass,
|
),
|
||||||
use_aiter_and_is_supported=False,
|
activation_quant_key=create_fp8_quant_key(
|
||||||
|
static=False, group_shape=act_quant_group_shape
|
||||||
|
),
|
||||||
|
out_dtype=torch.get_default_dtype(),
|
||||||
|
module_name="build_w8a8_block_fp8_runner",
|
||||||
)
|
)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|||||||
@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(
|
||||||
|
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
|||||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
quant_key = kFp8StaticTensorSym
|
quant_key = kFp8StaticTensorSym
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(
|
||||||
|
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
weight_shape=(hidden_size, hidden_size),
|
weight_shape=(hidden_size, hidden_size),
|
||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
|
input_dtype=dtype,
|
||||||
)
|
)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
]
|
]
|
||||||
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(
|
||||||
|
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
token_num = batch_size * seq_len
|
token_num = batch_size * seq_len
|
||||||
model = test_model_cls(hidden_size, token_num)
|
model = test_model_cls(hidden_size, token_num, dtype=dtype)
|
||||||
|
|
||||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||||
|
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
weight_shape=(hidden_size, hidden_size),
|
weight_shape=(hidden_size, hidden_size),
|
||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
|
input_dtype=self.vllm_config.model_config.dtype,
|
||||||
)
|
)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from vllm.config import (
|
|||||||
ModelConfig,
|
ModelConfig,
|
||||||
PassConfig,
|
PassConfig,
|
||||||
VllmConfig,
|
VllmConfig,
|
||||||
|
get_current_vllm_config,
|
||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
weight_shape=(hidden_size, hidden_size),
|
weight_shape=(hidden_size, hidden_size),
|
||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
|
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
weight_shape=(hidden_size, intermediate_size),
|
weight_shape=(hidden_size, intermediate_size),
|
||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
|
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import vllm.config
|
|||||||
import vllm.ir.ops
|
import vllm.ir.ops
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
from tests.compile.backend import TestBackend
|
from tests.compile.backend import TestBackend
|
||||||
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
from tests.utils import TestFP8Layer
|
||||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||||
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
||||||
from vllm.compilation.passes.fusion.rms_quant_fusion import (
|
from vllm.compilation.passes.fusion.rms_quant_fusion import (
|
||||||
@@ -28,19 +28,23 @@ from vllm.config import (
|
|||||||
VllmConfig,
|
VllmConfig,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
|
AiterFp8BlockScaledMMKernel,
|
||||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||||
|
CutlassFp8BlockScaledMMKernel,
|
||||||
CutlassFP8ScaledMMLinearKernel,
|
CutlassFP8ScaledMMLinearKernel,
|
||||||
|
DeepGemmFp8BlockScaledMMKernel,
|
||||||
|
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||||
FlashInferFP8ScaledMMLinearKernel,
|
FlashInferFP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearKernel,
|
|
||||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||||
ROCmFP8ScaledMMLinearKernel,
|
ROCmFP8ScaledMMLinearKernel,
|
||||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||||
|
TritonFp8BlockScaledMMKernel,
|
||||||
|
_KernelT,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
QuantKey,
|
create_fp8_quant_key,
|
||||||
ScaleDesc,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
cutlass_block_fp8_supported,
|
cutlass_block_fp8_supported,
|
||||||
@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
|||||||
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||||
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
||||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||||
# Blockwise group shapes (no kernel abstraction)
|
# Blockwise group shapes
|
||||||
(None, GroupShape(1, 128)),
|
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
|
||||||
(None, GroupShape(1, 64)),
|
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||||
|
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||||
|
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||||
|
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
|
||||||
]
|
]
|
||||||
|
|
||||||
# ROCm kernels
|
# ROCm kernels
|
||||||
@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
|||||||
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
||||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||||
# Blockwise group shapes (no kernel abstraction)
|
# Blockwise group shapes (no kernel abstraction)
|
||||||
(None, GroupShape(1, 128)),
|
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||||
(None, GroupShape(1, 64)),
|
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
|
||||||
]
|
]
|
||||||
|
|
||||||
KERNEL_GROUPSHAPE_COMBINATIONS = (
|
KERNEL_GROUPSHAPE_COMBINATIONS = (
|
||||||
@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
|||||||
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
|
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
|
||||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
||||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
||||||
# Blockwise (no kernel abstraction)
|
# Blockwise
|
||||||
(None, GroupShape(1, 128), True),
|
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
eps: float,
|
eps: float,
|
||||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
force_kernel: type[_KernelT] | None,
|
||||||
group_shape: GroupShape,
|
group_shape: GroupShape,
|
||||||
|
dtype: torch.dtype,
|
||||||
use_aiter_fusion: bool = False,
|
use_aiter_fusion: bool = False,
|
||||||
use_aiter_quant: bool = False,
|
use_aiter_quant: bool = False,
|
||||||
*args,
|
*args,
|
||||||
@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
|
|||||||
is_blockwise = group_shape.is_per_group()
|
is_blockwise = group_shape.is_per_group()
|
||||||
|
|
||||||
if is_blockwise:
|
if is_blockwise:
|
||||||
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
|
block_size = group_shape.col
|
||||||
self.activation_quant_key = QuantKey(
|
self.activation_quant_key = create_fp8_quant_key(
|
||||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
static=False, group_shape=group_shape
|
||||||
)
|
)
|
||||||
self.fp8_linear_layers = [
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
TestBlockFP8Layer(
|
static=True, group_shape=GroupShape(block_size, block_size)
|
||||||
weight_shape=(hidden_size, hidden_size),
|
|
||||||
group_shape=group_shape,
|
|
||||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
|
||||||
use_aiter_and_is_supported=use_aiter_quant,
|
|
||||||
transpose_weights=use_aiter_fusion,
|
|
||||||
)
|
|
||||||
for _ in range(3)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.enable_quant_fp8_custom_op = (
|
|
||||||
False
|
|
||||||
if use_aiter_quant
|
|
||||||
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
is_static = group_shape == GroupShape.PER_TENSOR
|
is_static = group_shape == GroupShape.PER_TENSOR
|
||||||
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
|
self.activation_quant_key = create_fp8_quant_key(
|
||||||
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
|
is_static, group_shape=group_shape
|
||||||
self.activation_quant_key = QuantKey(
|
|
||||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
|
||||||
)
|
)
|
||||||
self.weight_quant_key = QuantKey(
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
|
static=True, group_shape=group_shape
|
||||||
)
|
)
|
||||||
self.fp8_linear_layers = [
|
|
||||||
TestFP8Layer(
|
|
||||||
weight_shape=(hidden_size, hidden_size),
|
|
||||||
activation_quant_key=self.activation_quant_key,
|
|
||||||
weight_quant_key=self.weight_quant_key,
|
|
||||||
force_kernel=force_kernel,
|
|
||||||
)
|
|
||||||
for _ in range(3)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Enable aiter quantization if requested
|
self.fp8_linear_layers = [
|
||||||
for layer in self.fp8_linear_layers:
|
TestFP8Layer(
|
||||||
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
weight_shape=(hidden_size, hidden_size),
|
||||||
|
activation_quant_key=self.activation_quant_key,
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
force_kernel=force_kernel,
|
||||||
|
transpose_weights=use_aiter_fusion,
|
||||||
|
input_dtype=dtype,
|
||||||
|
)
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
# Enable aiter quantization if requested
|
||||||
0
|
for layer in self.fp8_linear_layers:
|
||||||
].is_quant_fp8_enabled()
|
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
||||||
|
|
||||||
|
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||||
|
0
|
||||||
|
].is_quant_fp8_enabled()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# avoid having graph input be an arg to a pattern directly
|
# avoid having graph input be an arg to a pattern directly
|
||||||
@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
|
|||||||
eps=eps,
|
eps=eps,
|
||||||
force_kernel=force_kernel,
|
force_kernel=force_kernel,
|
||||||
group_shape=group_shape,
|
group_shape=group_shape,
|
||||||
|
dtype=dtype,
|
||||||
use_aiter_fusion=False,
|
use_aiter_fusion=False,
|
||||||
use_aiter_quant=False,
|
use_aiter_quant=False,
|
||||||
)
|
)
|
||||||
@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
|
|||||||
eps=eps,
|
eps=eps,
|
||||||
force_kernel=force_kernel,
|
force_kernel=force_kernel,
|
||||||
group_shape=group_shape,
|
group_shape=group_shape,
|
||||||
|
dtype=dtype,
|
||||||
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
|
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
|
||||||
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
|
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
|||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
self.dtype = vllm_config.model_config.dtype
|
||||||
|
|
||||||
self.attn = Attention(
|
self.attn = Attention(
|
||||||
num_heads=self.num_qo_heads,
|
num_heads=self.num_qo_heads,
|
||||||
@@ -155,6 +156,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
|||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
input_dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
w = kwargs.get("w")
|
w = kwargs.get("w")
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
|
|||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
self.dtype = vllm_config.model_config.dtype
|
||||||
|
|
||||||
# Create kv_b_proj (ColumnParallelLinear) on device.
|
# Create kv_b_proj (ColumnParallelLinear) on device.
|
||||||
# Reuse weights from prior model instance when available, because
|
# Reuse weights from prior model instance when available, because
|
||||||
@@ -190,6 +191,7 @@ class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
|
|||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
input_dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
w = kwargs.get("w")
|
w = kwargs.get("w")
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ from vllm.model_executor.kernels.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
|
create_fp8_quant_key,
|
||||||
kFp8Dynamic128Sym,
|
kFp8Dynamic128Sym,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
kNvfp4Dynamic,
|
kNvfp4Dynamic,
|
||||||
@@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
quant_key = kFp8StaticTensorSym
|
quant_key = kFp8StaticTensorSym
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
force_kernel: FP8ScaledMMLinearKernel,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -68,6 +72,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
activation_quant_key=self.quant_key,
|
activation_quant_key=self.quant_key,
|
||||||
weight_quant_key=self.quant_key,
|
weight_quant_key=self.quant_key,
|
||||||
force_kernel=force_kernel,
|
force_kernel=force_kernel,
|
||||||
|
input_dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||||
@@ -137,14 +142,20 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||||
def __init__(self, hidden_size: int, **kwargs):
|
act_quant_key = kFp8Dynamic128Sym
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
weight_group_shape=GroupShape(128, 128),
|
static=True, group_shape=GroupShape(hidden_size, hidden_size)
|
||||||
act_quant_group_shape=GroupShape(1, 128),
|
)
|
||||||
cutlass_block_fp8_supported=False,
|
|
||||||
use_aiter_and_is_supported=True,
|
self.w8a8_block_fp8_linear = TestFP8Layer(
|
||||||
|
weight_shape=(hidden_size, hidden_size),
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
activation_quant_key=self.act_quant_key,
|
||||||
|
input_dtype=dtype,
|
||||||
)
|
)
|
||||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
|
|
||||||
@@ -157,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
|
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
|
||||||
return x2
|
return x2
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
|
|
||||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||||
backend = TestBackend(*passes)
|
backend = TestBackend(*passes)
|
||||||
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
|
model = model_class(
|
||||||
|
hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
# First dimension dynamic
|
# First dimension dynamic
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
|
|||||||
@@ -246,8 +246,9 @@ def default_vllm_config():
|
|||||||
"""
|
"""
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
|
||||||
with set_current_vllm_config(VllmConfig()):
|
config = VllmConfig()
|
||||||
yield
|
with set_current_vllm_config(config):
|
||||||
|
yield config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from tests.kernels.quant_utils import (
|
|||||||
native_w8a8_block_matmul,
|
native_w8a8_block_matmul,
|
||||||
)
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import cutlass_scaled_mm
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
cutlass_scaled_mm,
|
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
w8a8_triton_block_scaled_mm,
|
w8a8_triton_block_scaled_mm,
|
||||||
)
|
)
|
||||||
@@ -202,7 +202,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
|
|
||||||
# only aligned sizes are supported by deepgemm
|
# only aligned sizes are supported by deepgemm
|
||||||
if not should_use_deepgemm_for_fp8_linear(
|
if not should_use_deepgemm_for_fp8_linear(
|
||||||
output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
|
output_dtype=out_dtype, weight_shape=B_fp32.shape, supports_deep_gemm=True
|
||||||
):
|
):
|
||||||
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ from compressed_tensors.quantization import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from tests.models.utils import check_logprobs_close
|
from tests.models.utils import check_logprobs_close
|
||||||
|
from vllm.model_executor.kernels.linear import (
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
|
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
@@ -29,7 +32,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
|||||||
CompressedTensorsWNA16,
|
CompressedTensorsWNA16,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||||
cutlass_fp4_supported,
|
cutlass_fp4_supported,
|
||||||
)
|
)
|
||||||
@@ -473,16 +475,14 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
|
|||||||
qkv_proj = layer.self_attn.qkv_proj
|
qkv_proj = layer.self_attn.qkv_proj
|
||||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
|
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
|
||||||
assert isinstance(
|
assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel)
|
||||||
qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
|
|
||||||
)
|
|
||||||
|
|
||||||
assert qkv_proj.weight.dtype is fp8_dtype
|
assert qkv_proj.weight.dtype is fp8_dtype
|
||||||
assert qkv_proj.weight_scale.dtype is torch.float32
|
assert qkv_proj.weight_scale.dtype is torch.float32
|
||||||
assert len(qkv_proj.weight.shape) == 2
|
assert len(qkv_proj.weight.shape) == 2
|
||||||
assert len(qkv_proj.weight_scale.shape) == 2
|
assert len(qkv_proj.weight_scale.shape) == 2
|
||||||
|
|
||||||
input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
|
input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8
|
||||||
assert isinstance(input_quant_op, QuantFP8)
|
assert isinstance(input_quant_op, QuantFP8)
|
||||||
assert input_quant_op._forward_method in (
|
assert input_quant_op._forward_method in (
|
||||||
input_quant_op.forward_cuda,
|
input_quant_op.forward_cuda,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.config.model import ModelConfig
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.quantization.fp8 import (
|
from vllm.model_executor.layers.quantization.fp8 import (
|
||||||
Fp8Config,
|
Fp8Config,
|
||||||
@@ -406,6 +407,8 @@ def test_fp8_reloading(
|
|||||||
"If this is your use case, consider using a restore function like #26327"
|
"If this is your use case, consider using a restore function like #26327"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set model config as model_config.dtype is required in Fp8LinearMethod.
|
||||||
|
default_vllm_config.model_config = ModelConfig()
|
||||||
with torch.device("cuda:0"):
|
with torch.device("cuda:0"):
|
||||||
config = Fp8Config(
|
config = Fp8Config(
|
||||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
|
from vllm.config.model import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
@@ -46,7 +47,7 @@ def _snapshot_download_or_skip(model_id: str) -> str:
|
|||||||
not is_quant_method_supported("modelopt"),
|
not is_quant_method_supported("modelopt"),
|
||||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||||
)
|
)
|
||||||
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
def test_modelopt_fp8_checkpoint_setup(default_vllm_config, vllm_runner):
|
||||||
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
|
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
|
||||||
# TODO: provide a small publicly available test checkpoint
|
# TODO: provide a small publicly available test checkpoint
|
||||||
model_path = (
|
model_path = (
|
||||||
@@ -61,6 +62,8 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
|||||||
"This test requires a local ModelOpt FP8 checkpoint."
|
"This test requires a local ModelOpt FP8 checkpoint."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
|
||||||
|
default_vllm_config.model_config = ModelConfig()
|
||||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||||
|
|
||||||
def check_model(model):
|
def check_model(model):
|
||||||
@@ -120,11 +123,13 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
|||||||
not is_quant_method_supported("modelopt"),
|
not is_quant_method_supported("modelopt"),
|
||||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||||
)
|
)
|
||||||
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
|
def test_modelopt_fp8_pc_pt_checkpoint_setup(default_vllm_config, vllm_runner):
|
||||||
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
|
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
|
||||||
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
|
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
|
||||||
model_path = _snapshot_download_or_skip(model_id)
|
model_path = _snapshot_download_or_skip(model_id)
|
||||||
|
|
||||||
|
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
|
||||||
|
default_vllm_config.model_config = ModelConfig()
|
||||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||||
|
|
||||||
def check_model(model):
|
def check_model(model):
|
||||||
@@ -181,11 +186,13 @@ def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
|
|||||||
not is_quant_method_supported("modelopt"),
|
not is_quant_method_supported("modelopt"),
|
||||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||||
)
|
)
|
||||||
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
|
def test_modelopt_fp8_pb_wo_checkpoint_setup(default_vllm_config, vllm_runner):
|
||||||
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
|
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
|
||||||
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
|
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
|
||||||
model_path = _snapshot_download_or_skip(model_id)
|
model_path = _snapshot_download_or_skip(model_id)
|
||||||
|
|
||||||
|
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
|
||||||
|
default_vllm_config.model_config = ModelConfig()
|
||||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||||
|
|
||||||
def check_model(model):
|
def check_model(model):
|
||||||
|
|||||||
113
tests/utils.py
113
tests/utils.py
@@ -43,12 +43,10 @@ from vllm.distributed import (
|
|||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
FP8ScaledMMLinearKernel,
|
_KernelT,
|
||||||
init_fp8_linear_kernel,
|
init_fp8_linear_kernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
|
||||||
QuantKey,
|
QuantKey,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
@@ -1811,31 +1809,52 @@ class TestFP8Layer(torch.nn.Module):
|
|||||||
weight_shape: tuple[int, int],
|
weight_shape: tuple[int, int],
|
||||||
activation_quant_key: QuantKey,
|
activation_quant_key: QuantKey,
|
||||||
weight_quant_key: QuantKey,
|
weight_quant_key: QuantKey,
|
||||||
|
input_dtype: torch.dtype,
|
||||||
out_dtype: torch.dtype | None = None,
|
out_dtype: torch.dtype | None = None,
|
||||||
|
transpose_weights: bool = False,
|
||||||
device: torch.device | None = None,
|
device: torch.device | None = None,
|
||||||
force_kernel: FP8ScaledMMLinearKernel | None = None,
|
force_kernel: type[_KernelT] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
|
act_scale_desc = activation_quant_key.scale
|
||||||
is_static_activation_scale = activation_quant_key.scale.static
|
weight_scale_desc = weight_quant_key.scale
|
||||||
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
is_block_wise = act_scale_desc.group_shape.is_per_group()
|
||||||
|
if is_block_wise:
|
||||||
self.weight_scale = torch.rand(
|
block_size = weight_scale_desc.group_shape.col
|
||||||
weight_scale_shape, dtype=torch.float32, device=device
|
weight_scale_shape = weight_shape[0] // block_size
|
||||||
)
|
self.weight_scale_inv = torch.rand(
|
||||||
self.input_scale = (
|
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
||||||
torch.rand(1, dtype=torch.float32, device=device)
|
)
|
||||||
if is_static_activation_scale
|
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
||||||
else None
|
self.input_scale = None
|
||||||
)
|
self.weight_scale = None
|
||||||
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
if transpose_weights:
|
||||||
self.input_scale_ub = None
|
self.weight = self.weight.t()
|
||||||
|
else:
|
||||||
|
per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor()
|
||||||
|
is_static_activation_scale = act_scale_desc.static
|
||||||
|
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
||||||
|
self.weight_scale_inv = None
|
||||||
|
self.weight_scale = torch.rand(
|
||||||
|
weight_scale_shape, dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
self.input_scale = (
|
||||||
|
torch.rand(1, dtype=torch.float32, device=device)
|
||||||
|
if is_static_activation_scale
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.weight = (
|
||||||
|
torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
||||||
|
)
|
||||||
|
self.input_scale_ub = None
|
||||||
|
|
||||||
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
|
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
|
||||||
|
|
||||||
self.kernel = init_fp8_linear_kernel(
|
self.kernel = init_fp8_linear_kernel(
|
||||||
activation_quant_key=activation_quant_key,
|
activation_quant_key=activation_quant_key,
|
||||||
weight_quant_key=weight_quant_key,
|
weight_quant_key=weight_quant_key,
|
||||||
|
weight_shape=weight_shape,
|
||||||
|
input_dtype=input_dtype,
|
||||||
out_dtype=out_dtype,
|
out_dtype=out_dtype,
|
||||||
force_kernel=force_kernel,
|
force_kernel=force_kernel,
|
||||||
)
|
)
|
||||||
@@ -1847,61 +1866,3 @@ class TestFP8Layer(torch.nn.Module):
|
|||||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.kernel.apply_weights(self, y, bias)
|
return self.kernel.apply_weights(self, y, bias)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
|
|
||||||
# after refactoring W8A8BlockFp8LinearOp.
|
|
||||||
# https://github.com/vllm-project/vllm/issues/31818
|
|
||||||
class TestBlockFP8Layer:
|
|
||||||
"""
|
|
||||||
Test helper for blockwise FP8 linear operations. Creates random weights
|
|
||||||
and scales for W8A8BlockFp8LinearOp.
|
|
||||||
|
|
||||||
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
|
|
||||||
abstraction (ScaledMMLinearKernel) for blockwise quantization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
weight_shape: Shape of the weight tensor (out_features, in_features).
|
|
||||||
group_shape: Blockwise quantization group shape.
|
|
||||||
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
|
|
||||||
use_aiter_and_is_supported: Whether to use aiter quantization ops.
|
|
||||||
transpose_weights: Whether to transpose weights after creation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
weight_shape: tuple[int, int],
|
|
||||||
group_shape: GroupShape,
|
|
||||||
cutlass_block_fp8_supported: bool = False,
|
|
||||||
use_aiter_and_is_supported: bool = False,
|
|
||||||
transpose_weights: bool = False,
|
|
||||||
):
|
|
||||||
weight_scale_shape = weight_shape[0] // group_shape[1]
|
|
||||||
self.weight_scale = torch.rand(
|
|
||||||
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
|
||||||
)
|
|
||||||
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
|
||||||
self.input_scale = None
|
|
||||||
if transpose_weights:
|
|
||||||
self.weight = self.weight.t()
|
|
||||||
|
|
||||||
self.linear_op = W8A8BlockFp8LinearOp(
|
|
||||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
|
||||||
act_quant_group_shape=group_shape,
|
|
||||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
|
||||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return self.linear_op.apply(
|
|
||||||
input=y,
|
|
||||||
weight=self.weight,
|
|
||||||
weight_scale=self.weight_scale,
|
|
||||||
input_scale=self.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_quant_fp8_enabled(self) -> bool:
|
|
||||||
return self.linear_op.input_quant_op.enabled()
|
|
||||||
|
|||||||
@@ -1002,11 +1002,11 @@ class VllmBackend:
|
|||||||
)
|
)
|
||||||
hash_content = []
|
hash_content = []
|
||||||
for filepath in forward_code_files:
|
for filepath in forward_code_files:
|
||||||
hash_content.append(filepath)
|
|
||||||
if filepath == "<string>":
|
if filepath == "<string>":
|
||||||
# This means the function was dynamically generated, with
|
# This means the function was dynamically generated, with
|
||||||
# e.g. exec(). We can't actually check these.
|
# e.g. exec(). We can't actually check these.
|
||||||
continue
|
continue
|
||||||
|
hash_content.append(filepath)
|
||||||
try:
|
try:
|
||||||
with open(filepath) as f:
|
with open(filepath) as f:
|
||||||
hash_content.append(f.read())
|
hash_content.append(f.read())
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.kernels.linear.base import (
|
||||||
|
MMLinearKernel,
|
||||||
|
MMLinearLayerConfig,
|
||||||
|
)
|
||||||
from vllm.model_executor.kernels.linear.mixed_precision import (
|
from vllm.model_executor.kernels.linear.mixed_precision import (
|
||||||
MPLinearKernel,
|
MPLinearKernel,
|
||||||
MPLinearLayerConfig,
|
MPLinearLayerConfig,
|
||||||
@@ -52,24 +56,30 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
|||||||
XPUwNa16LinearKernel,
|
XPUwNa16LinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm import (
|
from vllm.model_executor.kernels.linear.scaled_mm import (
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
Int8ScaledMMLinearKernel,
|
Int8ScaledMMLinearKernel,
|
||||||
Int8ScaledMMLinearLayerConfig,
|
Int8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||||
|
AiterFp8BlockScaledMMKernel,
|
||||||
AiterInt8ScaledMMLinearKernel,
|
AiterInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
||||||
CPUInt8ScaledMMLinearKernel,
|
CPUInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
||||||
|
CutlassFp8BlockScaledMMKernel,
|
||||||
CutlassFP8ScaledMMLinearKernel,
|
CutlassFP8ScaledMMLinearKernel,
|
||||||
CutlassInt8ScaledMMLinearKernel,
|
CutlassInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.kernels.linear.scaled_mm.deep_gemm import (
|
||||||
|
DeepGemmFp8BlockScaledMMKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
||||||
|
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||||
FlashInferFP8ScaledMMLinearKernel,
|
FlashInferFP8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
|
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
|
||||||
@@ -84,6 +94,7 @@ from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
|
|||||||
ROCmFP8ScaledMMLinearKernel,
|
ROCmFP8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||||
|
TritonFp8BlockScaledMMKernel,
|
||||||
TritonInt8ScaledMMLinearKernel,
|
TritonInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
|
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
|
||||||
@@ -128,6 +139,23 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# in priority/performance order (when available)
|
||||||
|
_POSSIBLE_FP8_BLOCK_KERNELS: dict[
|
||||||
|
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]]
|
||||||
|
] = {
|
||||||
|
PlatformEnum.CUDA: [
|
||||||
|
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||||
|
DeepGemmFp8BlockScaledMMKernel,
|
||||||
|
CutlassFp8BlockScaledMMKernel,
|
||||||
|
TritonFp8BlockScaledMMKernel,
|
||||||
|
],
|
||||||
|
PlatformEnum.ROCM: [
|
||||||
|
AiterFp8BlockScaledMMKernel,
|
||||||
|
TritonFp8BlockScaledMMKernel,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
# in priority/performance order (when available)
|
# in priority/performance order (when available)
|
||||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||||
PlatformEnum.CUDA: [
|
PlatformEnum.CUDA: [
|
||||||
@@ -152,8 +180,10 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
# TODO make all kernels inherit from MMLinearKernel
|
||||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
|
# then bound _KernelT only to MMLinearKernel
|
||||||
|
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel)
|
||||||
|
_KernelConfigT = TypeVar("_KernelConfigT", bound=MMLinearLayerConfig)
|
||||||
|
|
||||||
|
|
||||||
def is_supported_and_can_implement_kernel(
|
def is_supported_and_can_implement_kernel(
|
||||||
@@ -243,32 +273,61 @@ def choose_scaled_mm_linear_kernel(
|
|||||||
def init_fp8_linear_kernel(
|
def init_fp8_linear_kernel(
|
||||||
activation_quant_key: QuantKey,
|
activation_quant_key: QuantKey,
|
||||||
weight_quant_key: QuantKey,
|
weight_quant_key: QuantKey,
|
||||||
|
weight_shape: tuple[int, int],
|
||||||
|
input_dtype: torch.dtype,
|
||||||
out_dtype: torch.dtype,
|
out_dtype: torch.dtype,
|
||||||
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
|
force_kernel: type[_KernelT] | None = None,
|
||||||
module_name: str | None = None,
|
module_name: str | None = None,
|
||||||
) -> FP8ScaledMMLinearKernel:
|
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
|
||||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||||
weight_quant_key=weight_quant_key,
|
weight_quant_key=weight_quant_key,
|
||||||
activation_quant_key=activation_quant_key,
|
activation_quant_key=activation_quant_key,
|
||||||
|
weight_shape=weight_shape,
|
||||||
|
input_dtype=input_dtype,
|
||||||
out_dtype=out_dtype,
|
out_dtype=out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(
|
if activation_quant_key.scale.group_shape.is_per_group():
|
||||||
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
)
|
config=scaled_mm_linear_kernel_config,
|
||||||
|
possible_kernels=_POSSIBLE_FP8_BLOCK_KERNELS, # type: ignore[misc]
|
||||||
|
force_kernel=force_kernel,
|
||||||
|
)
|
||||||
|
if module_name:
|
||||||
|
logger.info_once(
|
||||||
|
"Selected %s for %s",
|
||||||
|
kernel_type.__name__,
|
||||||
|
module_name,
|
||||||
|
scope="global",
|
||||||
|
)
|
||||||
|
|
||||||
if module_name:
|
return kernel_type(
|
||||||
logger.info_once(
|
scaled_mm_linear_kernel_config,
|
||||||
"Selected %s for %s",
|
|
||||||
kernel_type.__name__,
|
|
||||||
module_name,
|
|
||||||
scope="global",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return kernel_type(
|
else:
|
||||||
scaled_mm_linear_kernel_config,
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
|
config=scaled_mm_linear_kernel_config,
|
||||||
)
|
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc]
|
||||||
|
force_kernel=force_kernel,
|
||||||
|
)
|
||||||
|
if module_name:
|
||||||
|
logger.info_once(
|
||||||
|
"Selected %s for %s",
|
||||||
|
kernel_type.__name__,
|
||||||
|
module_name,
|
||||||
|
scope="global",
|
||||||
|
)
|
||||||
|
|
||||||
|
return kernel_type(
|
||||||
|
scaled_mm_linear_kernel_config,
|
||||||
|
layer_param_names=[
|
||||||
|
"weight",
|
||||||
|
"weight_scale",
|
||||||
|
"input_scale",
|
||||||
|
"input_scale_ub",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_int8_linear_kernel(
|
def init_int8_linear_kernel(
|
||||||
@@ -433,4 +492,7 @@ __all__ = [
|
|||||||
"MarlinLinearKernel",
|
"MarlinLinearKernel",
|
||||||
"XPUW4A8IntLinearKernel",
|
"XPUW4A8IntLinearKernel",
|
||||||
"XPUwNa16LinearKernel",
|
"XPUwNa16LinearKernel",
|
||||||
|
"_KernelT",
|
||||||
|
"DeepGemmFp8BlockScaledMMKernel",
|
||||||
|
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel",
|
||||||
]
|
]
|
||||||
|
|||||||
324
vllm/model_executor/kernels/linear/base.py
Normal file
324
vllm/model_executor/kernels/linear/base.py
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, ClassVar, Generic, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MMLinearLayerConfig: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
"""Base class for quantized layer parameters.
|
||||||
|
|
||||||
|
This class provides a typed interface for accessing quantized weights and scales
|
||||||
|
from layer modules. It serves as a parameter container that can be extracted from
|
||||||
|
layers and passed to kernel implementations.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
weight: The quantized weight tensor
|
||||||
|
weight_scale: weight scaling factors
|
||||||
|
input_scale: Optional input scaling factors
|
||||||
|
|
||||||
|
Class Variables:
|
||||||
|
WEIGHT: Attribute name for weight tensor on the layer module
|
||||||
|
WEIGHT_SCALE: Attribute name for weight scale tensor on the layer module
|
||||||
|
INPUT_SCALE: Attribute name for input scale tensor on the layer module
|
||||||
|
|
||||||
|
Important:
|
||||||
|
The string values of WEIGHT, WEIGHT_SCALE, and INPUT_SCALE class variables
|
||||||
|
MUST match the attribute names used in the corresponding quantization method's
|
||||||
|
create_weights() implementation.
|
||||||
|
For example, if FP8LinearMethod.create_weights()
|
||||||
|
sets layer.weight and layer.weight_scale,
|
||||||
|
then WEIGHT="weight" and
|
||||||
|
WEIGHT_SCALE="weight_scale" must be used here.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```python
|
||||||
|
# Extract parameters from a quantized layer
|
||||||
|
params = Params.from_layer(layer)
|
||||||
|
|
||||||
|
# Access typed parameters
|
||||||
|
output = func(input, params.weight, params.weight_scale)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
weight: torch.Tensor
|
||||||
|
weight_scale: torch.Tensor
|
||||||
|
input_scale: torch.Tensor | None
|
||||||
|
|
||||||
|
# Attribute names on the layer
|
||||||
|
WEIGHT: ClassVar[str] = "weight"
|
||||||
|
WEIGHT_SCALE: ClassVar[str] = "weight_scale"
|
||||||
|
INPUT_SCALE: ClassVar[str] = "input_scale"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_layer(cls, layer: torch.nn.Module) -> Self:
|
||||||
|
return cls(
|
||||||
|
weight=getattr(layer, cls.WEIGHT),
|
||||||
|
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
|
||||||
|
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FP8Params(Params):
|
||||||
|
"""FP8 layer parameters with typed fields"""
|
||||||
|
|
||||||
|
input_scale_ub: torch.Tensor | None
|
||||||
|
|
||||||
|
INPUT_SCALE_UB: ClassVar[str] = "input_scale_ub"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_layer(cls, layer: torch.nn.Module) -> "FP8Params":
|
||||||
|
"""Extract parameters from layer"""
|
||||||
|
return cls(
|
||||||
|
weight=getattr(layer, cls.WEIGHT),
|
||||||
|
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
|
||||||
|
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||||
|
input_scale_ub=getattr(layer, cls.INPUT_SCALE_UB, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Int8Params(Params):
|
||||||
|
"""Int8 layer parameters with typed fields"""
|
||||||
|
|
||||||
|
input_zero_point: torch.Tensor | None
|
||||||
|
azp_adj: torch.Tensor | None
|
||||||
|
|
||||||
|
INPUT_ZERO_POINT: ClassVar[str] = "input_zero_point"
|
||||||
|
AZP_ADJ: ClassVar[str] = "azp_adj"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_layer(cls, layer: torch.nn.Module) -> "Int8Params":
|
||||||
|
"""Extract parameters from layer"""
|
||||||
|
return cls(
|
||||||
|
weight=getattr(layer, cls.WEIGHT),
|
||||||
|
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
|
||||||
|
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||||
|
input_zero_point=getattr(layer, cls.INPUT_ZERO_POINT, None),
|
||||||
|
azp_adj=getattr(layer, cls.AZP_ADJ, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_ParamsT = TypeVar("_ParamsT", bound=Params)
|
||||||
|
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class MMLinearKernel(ABC, Generic[_ConfigT, _ParamsT]):
|
||||||
|
"""Abstract base class for quantized matrix multiplication kernels.
|
||||||
|
|
||||||
|
This class provides the interface for implementing custom quantized linear layer
|
||||||
|
kernels in vLLM. Subclasses should implement specific quantization strategies
|
||||||
|
(e.g., FP8, INT8) and their corresponding compute kernels.
|
||||||
|
|
||||||
|
Generic Type Parameters:
|
||||||
|
_ConfigT: Configuration type for the kernel (subclass of MMLinearLayerConfig).
|
||||||
|
Contains kernel-specific settings like quantization keys, dtypes, etc.
|
||||||
|
_ParamsT: Parameter type for the kernel (subclass of Params).
|
||||||
|
Defines the quantized weights and scales needed by the kernel.
|
||||||
|
|
||||||
|
Typical Usage:
|
||||||
|
1. Define a config dataclass inheriting from MMLinearLayerConfig
|
||||||
|
2. Define a params dataclass inheriting from Params (or FP8Params/Int8Params)
|
||||||
|
3. Subclass MMLinearKernel with your config and params types
|
||||||
|
4. Implement all abstract methods
|
||||||
|
5. Register the kernel with the quantization method
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class MyKernelConfig(MMLinearLayerConfig):
|
||||||
|
static: bool
|
||||||
|
output_dtype: torch.dtype
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MyKernelParams(FP8Params):
|
||||||
|
custom_scale: torch.Tensor
|
||||||
|
CUSTOM_SCALE: ClassVar[str] = "custom_scale"
|
||||||
|
|
||||||
|
|
||||||
|
class MyKernel(MMLinearKernel[MyKernelConfig, MyKernelParams]):
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, compute_capability=None):
|
||||||
|
if compute_capability and compute_capability < 90:
|
||||||
|
return False, "Requires compute capability >= 9.0"
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, config):
|
||||||
|
if not config.static:
|
||||||
|
return False, "Only static quantization supported"
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer):
|
||||||
|
# Preprocess weights for the kernel
|
||||||
|
params = self._get_layer_params(layer)
|
||||||
|
processed = preprocess_weights(params.weight)
|
||||||
|
replace_parameter(layer, params.WEIGHT, processed)
|
||||||
|
|
||||||
|
def _get_layer_params(self, layer, **kwargs):
|
||||||
|
return MyKernelParams.from_layer(layer)
|
||||||
|
|
||||||
|
def apply_weights(self, layer, x, bias=None, **kwargs):
|
||||||
|
params = self._get_layer_params(layer)
|
||||||
|
# Call your custom kernel
|
||||||
|
output = my_custom_kernel(x, params.weight, params.weight_scale)
|
||||||
|
if bias is not None:
|
||||||
|
output += bias
|
||||||
|
return output
|
||||||
|
```
|
||||||
|
|
||||||
|
Lifecycle:
|
||||||
|
1. Kernel selection: is_supported() and can_implement() check compatibility
|
||||||
|
2. Initialization: __init__() creates kernel instance with config
|
||||||
|
3. Weight loading: process_weights_after_loading() preprocesses weights
|
||||||
|
4. Inference: apply_weights() executes the quantized matmul
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def is_supported(
|
||||||
|
cls, compute_capability: int | None = None
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
"""Check if this kernel is supported on the current hardware.
|
||||||
|
|
||||||
|
This method checks hardware-level compatibility (e.g., GPU architecture,
|
||||||
|
compute capability, available instructions). It's called during kernel
|
||||||
|
selection to filter out kernels that cannot run on the current device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compute_capability: GPU compute capability (e.g., 80 for A100, 90 for H100).
|
||||||
|
If None, should check the current device.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (is_supported, reason):
|
||||||
|
- is_supported: True if the kernel can run on this hardware
|
||||||
|
- reason: If not supported, a string explaining why; otherwise None
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def can_implement(cls, config: _ConfigT) -> tuple[bool, str | None]:
|
||||||
|
"""Check if this kernel can implement the given configuration.
|
||||||
|
|
||||||
|
This method checks configuration-level compatibility (e.g., quantization
|
||||||
|
scheme, group sizes, static vs dynamic quantization). It's called after
|
||||||
|
is_supported() to determine if this kernel can handle the specific
|
||||||
|
quantization configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The kernel configuration to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (can_implement, reason):
|
||||||
|
- can_implement: True if this kernel supports the config
|
||||||
|
- reason: If not supported, a string explaining why; otherwise None
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __init__(self, config: _ConfigT) -> None:
|
||||||
|
"""Initialize the kernel with the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Kernel-specific configuration containing settings like
|
||||||
|
quantization keys, output dtypes, etc.
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
"""Process and transform weights after loading from checkpoint.
|
||||||
|
|
||||||
|
This method is called once after weights are loaded but before inference.
|
||||||
|
Use it to preprocess weights into the format required by your kernel
|
||||||
|
(e.g., reordering, padding, format conversion).
|
||||||
|
|
||||||
|
Modifications should be done in-place using replace_parameter() to ensure
|
||||||
|
the layer's parameters are properly updated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: The layer module containing the weights to process
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
def process_weights_after_loading(self, layer):
|
||||||
|
params = self._get_layer_params(layer)
|
||||||
|
# Reorder weights for better memory access
|
||||||
|
weight_reordered = reorder_weights(params.weight)
|
||||||
|
replace_parameter(layer, params.WEIGHT, weight_reordered)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# return a covariant type in the subclass
|
||||||
|
@abstractmethod
|
||||||
|
def _get_layer_params(self, layer: torch.nn.Module, **kwargs: Any) -> _ParamsT:
|
||||||
|
"""Extract typed parameters from the layer module.
|
||||||
|
|
||||||
|
This internal method retrieves the quantized weights and scales from
|
||||||
|
the layer as a typed parameter object. Subclasses should typically
|
||||||
|
delegate to ParamsClass.from_layer().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: The layer module containing the parameters
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A typed parameter object containing weights, scales, and other
|
||||||
|
quantization parameters
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
def _get_layer_params(self, layer, **kwargs):
|
||||||
|
return MyKernelParams.from_layer(layer)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_output_padding(self) -> int | None:
|
||||||
|
"""Get the number of output tokens to pad for this kernel.
|
||||||
|
|
||||||
|
Some kernels require input padding for optimal performance.
|
||||||
|
Override this method to specify padding requirements.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens to pad, or None for no padding (default)
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Apply the quantized weights to the input tensor.
|
||||||
|
|
||||||
|
This is the main inference method that performs the quantized matrix
|
||||||
|
multiplication. It should handle input quantization (if needed), call
|
||||||
|
the underlying kernel, and apply bias.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: The layer module containing the quantized weights
|
||||||
|
x: Input tensor of shape [..., in_features]
|
||||||
|
bias: Optional bias tensor of shape [out_features]
|
||||||
|
**kwargs: Additional kernel-specific arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor of shape [..., out_features]
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
process_fp8_weight_block_strategy,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.utils import replace_parameter
|
||||||
|
|
||||||
|
from ..base import (
|
||||||
|
FP8Params,
|
||||||
|
MMLinearKernel,
|
||||||
|
)
|
||||||
|
from .ScaledMMLinearKernel import FP8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FP8BlockParams(FP8Params):
|
||||||
|
weight_scale_inv: torch.Tensor | None
|
||||||
|
weight_scale: torch.Tensor | None
|
||||||
|
|
||||||
|
WEIGHT_SCALE_INV: ClassVar[str] = "weight_scale_inv"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_layer(cls, layer: torch.nn.Module) -> Self:
|
||||||
|
return cls(
|
||||||
|
weight=getattr(layer, cls.WEIGHT),
|
||||||
|
weight_scale_inv=getattr(layer, cls.WEIGHT_SCALE_INV, None),
|
||||||
|
weight_scale=getattr(layer, cls.WEIGHT_SCALE, None),
|
||||||
|
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||||
|
input_scale_ub=getattr(layer, cls.INPUT_SCALE_UB, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8BlockScaledMMLinearKernel(
|
||||||
|
MMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8BlockParams], ABC
|
||||||
|
):
|
||||||
|
# Set to False in subclasses that accept BF16 input directly (e.g. FlashInfer)
|
||||||
|
# and therefore do not need the input quantization step in apply_weights.
|
||||||
|
apply_input_quant: ClassVar[bool] = True
|
||||||
|
|
||||||
|
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
act_scale_descriptor = config.activation_quant_key.scale
|
||||||
|
self.weight_group_shape = config.weight_quant_key.scale.group_shape
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=act_scale_descriptor.static,
|
||||||
|
group_shape=act_scale_descriptor.group_shape,
|
||||||
|
num_token_padding=self.get_output_padding(),
|
||||||
|
use_ue8m0=False,
|
||||||
|
)
|
||||||
|
self.use_triton = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
act_quant_key = config.activation_quant_key
|
||||||
|
if act_quant_key.scale.static:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Only dynamic per token group activation quantization is supported.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def _get_layer_params(self, layer: torch.nn.Module, **kwargs) -> FP8BlockParams:
|
||||||
|
return FP8BlockParams.from_layer(layer)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||||
|
params = self._get_layer_params(layer)
|
||||||
|
# Fp8LinearMethod registered weight scale
|
||||||
|
# buffer as weight_scale_inv unlike compressed tensors.
|
||||||
|
weight_scale = (
|
||||||
|
params.weight_scale
|
||||||
|
if params.weight_scale_inv is None
|
||||||
|
else params.weight_scale_inv
|
||||||
|
)
|
||||||
|
scale_attr_name = (
|
||||||
|
params.WEIGHT_SCALE
|
||||||
|
if params.weight_scale_inv is None
|
||||||
|
else params.WEIGHT_SCALE_INV
|
||||||
|
)
|
||||||
|
new_weight, new_weight_scale = process_fp8_weight_block_strategy(
|
||||||
|
params.weight,
|
||||||
|
weight_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
replace_parameter(layer, params.WEIGHT, new_weight.data)
|
||||||
|
replace_parameter(layer, scale_attr_name, new_weight_scale.data)
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
params = self._get_layer_params(layer)
|
||||||
|
weight = params.weight
|
||||||
|
weight_scale = (
|
||||||
|
params.weight_scale
|
||||||
|
if params.weight_scale_inv is None
|
||||||
|
else params.weight_scale_inv
|
||||||
|
)
|
||||||
|
input_scale = params.input_scale
|
||||||
|
scale_up = params.input_scale_ub
|
||||||
|
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
input_2d = x.view(-1, x.shape[-1])
|
||||||
|
output_shape = [*x.shape[:-1], weight.shape[0]]
|
||||||
|
|
||||||
|
if self.apply_input_quant:
|
||||||
|
q_input, input_scale = self.quant_fp8(
|
||||||
|
input_2d, input_scale, scale_up, use_triton=self.use_triton
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_input = input_2d
|
||||||
|
# Provide a concrete placeholder so apply_block_scaled_mm args are
|
||||||
|
# always Tensors. Subclasses with apply_input_quant=False must not
|
||||||
|
# use As in apply_block_scaled_mm.
|
||||||
|
input_scale = (
|
||||||
|
input_scale if input_scale is not None else input_2d.new_ones(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.apply_block_scaled_mm(
|
||||||
|
A=q_input,
|
||||||
|
B=weight,
|
||||||
|
As=input_scale,
|
||||||
|
Bs=weight_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
output = output + bias
|
||||||
|
return output.to(dtype=out_dtype).view(*output_shape)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC):
|
||||||
|
"""Dynamic FP8 block-scaled kernel that dispatches at runtime.
|
||||||
|
|
||||||
|
Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides
|
||||||
|
apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond.
|
||||||
|
|
||||||
|
Subclasses must define:
|
||||||
|
base_type: The primary kernel class.
|
||||||
|
fallback_type: The fallback kernel class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
|
||||||
|
fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
|
||||||
|
|
||||||
|
def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.base = self.base_type(config)
|
||||||
|
self.fallback = self.fallback_type(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(
|
||||||
|
cls, compute_capability: int | None = None
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
is_base_supported, reason_1 = cls.base_type.is_supported(compute_capability)
|
||||||
|
is_fallback_supported, reason_2 = cls.fallback_type.is_supported(
|
||||||
|
compute_capability
|
||||||
|
)
|
||||||
|
if is_base_supported and is_fallback_supported:
|
||||||
|
return True, None
|
||||||
|
if not is_base_supported and not is_fallback_supported:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"base is not supported due to {reason_1}; "
|
||||||
|
f"fallback is not supported due to {reason_2}",
|
||||||
|
)
|
||||||
|
if not is_base_supported:
|
||||||
|
return False, f"base is not supported due to {reason_1}"
|
||||||
|
return False, f"fallback is not supported due to {reason_2}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(
|
||||||
|
cls, config: "FP8ScaledMMLinearLayerConfig"
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
can_implement_base, reason_1 = cls.base_type.can_implement(config)
|
||||||
|
can_implement_fallback, reason_2 = cls.fallback_type.can_implement(config)
|
||||||
|
if can_implement_base and can_implement_fallback:
|
||||||
|
return True, None
|
||||||
|
if not can_implement_base and not can_implement_fallback:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"base cannot implement due to {reason_1}; "
|
||||||
|
f"fallback cannot implement due to {reason_2}",
|
||||||
|
)
|
||||||
|
if not can_implement_base:
|
||||||
|
return False, f"base cannot implement due to {reason_1}"
|
||||||
|
return False, f"fallback cannot implement due to {reason_2}"
|
||||||
@@ -14,14 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from ..base import MMLinearLayerConfig
|
||||||
@dataclass
|
|
||||||
class ScaledMMLinearLayerConfig:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
class Int8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
|
||||||
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
|
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
|
||||||
is_static_input_scheme: bool
|
is_static_input_scheme: bool
|
||||||
is_channelwise: bool
|
is_channelwise: bool
|
||||||
@@ -29,10 +26,12 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
class FP8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
|
||||||
weight_quant_key: QuantKey
|
weight_quant_key: QuantKey
|
||||||
activation_quant_key: QuantKey
|
activation_quant_key: QuantKey
|
||||||
out_dtype: torch.dtype | None
|
weight_shape: tuple[int, int]
|
||||||
|
input_dtype: torch.dtype
|
||||||
|
out_dtype: torch.dtype
|
||||||
|
|
||||||
|
|
||||||
_FP8ParamsT = tuple[
|
_FP8ParamsT = tuple[
|
||||||
@@ -50,7 +49,7 @@ _Int8ParamsT = tuple[
|
|||||||
]
|
]
|
||||||
|
|
||||||
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
|
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
|
||||||
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
|
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
|
||||||
|
|
||||||
|
|
||||||
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
|
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
|
||||||
|
|||||||
@@ -4,6 +4,9 @@
|
|||||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||||
AiterInt8ScaledMMLinearKernel,
|
AiterInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import (
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
||||||
CPUInt8ScaledMMLinearKernel,
|
CPUInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
@@ -31,7 +34,6 @@ from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
|
|||||||
Int8ScaledMMLinearKernel,
|
Int8ScaledMMLinearKernel,
|
||||||
Int8ScaledMMLinearLayerConfig,
|
Int8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||||
TritonInt8ScaledMMLinearKernel,
|
TritonInt8ScaledMMLinearKernel,
|
||||||
@@ -55,4 +57,5 @@ __all__ = [
|
|||||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||||
"ROCmFP8ScaledMMLinearKernel",
|
"ROCmFP8ScaledMMLinearKernel",
|
||||||
"TritonInt8ScaledMMLinearKernel",
|
"TritonInt8ScaledMMLinearKernel",
|
||||||
|
"Fp8BlockScaledMMLinearKernel",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,8 +6,15 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .BlockScaledMMLinearKernel import (
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
||||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
@@ -107,3 +114,54 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
|||||||
# b to be [N, K]
|
# b to be [N, K]
|
||||||
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||||
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||||
|
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
n, k = config.weight_shape
|
||||||
|
|
||||||
|
self.use_triton = (
|
||||||
|
not current_platform.is_fp8_fnuz()
|
||||||
|
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, compute_capability=None):
|
||||||
|
return (
|
||||||
|
rocm_aiter_ops.is_linear_enabled(),
|
||||||
|
"Only supported on ROCm platform \
|
||||||
|
with aiter package installed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
can_implement_base, reason = super().can_implement(config)
|
||||||
|
if not can_implement_base:
|
||||||
|
return can_implement_base, reason
|
||||||
|
|
||||||
|
act_quant_desc = config.activation_quant_key.scale
|
||||||
|
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Supports only dynamic per token group activation "
|
||||||
|
"quantization with group_shape=(1,128).",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
if self.use_triton:
|
||||||
|
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
|
||||||
|
else:
|
||||||
|
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
|
||||||
|
|
||||||
|
return gemm_a8w8_blockscale_op(
|
||||||
|
A, B, As, Bs, list(self.weight_group_shape), output_dtype=out_dtype
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,12 +5,19 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
@@ -171,3 +178,143 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||||
)
|
)
|
||||||
return output.view(*output_shape)
|
return output.view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||||
|
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
act_scale_descriptor = config.activation_quant_key.scale
|
||||||
|
self.weight_group_shape = config.weight_quant_key.scale.group_shape
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=act_scale_descriptor.static,
|
||||||
|
group_shape=act_scale_descriptor.group_shape,
|
||||||
|
num_token_padding=self.get_output_padding(),
|
||||||
|
use_ue8m0=False,
|
||||||
|
column_major_scales=True,
|
||||||
|
)
|
||||||
|
self.is_hopper = current_platform.is_device_capability(90)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, compute_capability=None):
|
||||||
|
if not CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"The device compute capability of"
|
||||||
|
f"{compute_capability} is not supported.",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
can_implement_base, reason = super().can_implement(config)
|
||||||
|
if not can_implement_base:
|
||||||
|
return can_implement_base, reason
|
||||||
|
|
||||||
|
act_quant_desc = config.activation_quant_key.scale
|
||||||
|
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Supports only dynamic per token group activation "
|
||||||
|
"quantization with group_shape=(1,128).",
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
if self.is_hopper:
|
||||||
|
return torch.ops.vllm.padded_cutlass(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
list(self.weight_group_shape),
|
||||||
|
out_dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ops.cutlass_scaled_mm(
|
||||||
|
A,
|
||||||
|
B.T,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
scale_a=As,
|
||||||
|
scale_b=Bs.T,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_scaled_mm(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return ops.cutlass_scaled_mm(
|
||||||
|
A,
|
||||||
|
B.T,
|
||||||
|
out_dtype=output_dtype,
|
||||||
|
scale_a=As,
|
||||||
|
scale_b=Bs.T,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _padded_cutlass(
|
||||||
|
qx: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
pad_multiple = 4
|
||||||
|
dim = qx.shape[0]
|
||||||
|
padded = (
|
||||||
|
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
|
||||||
|
)
|
||||||
|
|
||||||
|
has_pad = padded > dim
|
||||||
|
|
||||||
|
if has_pad:
|
||||||
|
padded_shape = [padded, *qx.shape[1:]]
|
||||||
|
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
|
||||||
|
padded_qx[0 : qx.shape[0], ...].copy_(qx)
|
||||||
|
|
||||||
|
padded_x_scale_shape = [*x_scale.shape[1:], padded]
|
||||||
|
padded_x_scale = torch.ones(
|
||||||
|
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
|
||||||
|
).permute(-1, -2)
|
||||||
|
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
||||||
|
|
||||||
|
output = cutlass_scaled_mm(
|
||||||
|
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
|
||||||
|
)
|
||||||
|
return output[0 : qx.shape[0], ...]
|
||||||
|
else:
|
||||||
|
return cutlass_scaled_mm(
|
||||||
|
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _padded_cutlass_fake(
|
||||||
|
qx: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty(
|
||||||
|
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
"padded_cutlass",
|
||||||
|
_padded_cutlass,
|
||||||
|
fake_impl=_padded_cutlass_fake,
|
||||||
|
)
|
||||||
|
|||||||
156
vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py
Normal file
156
vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
deepgemm_post_process_fp8_weight_block,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.utils import replace_parameter
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.deep_gemm import (
|
||||||
|
fp8_gemm_nt,
|
||||||
|
is_deep_gemm_e8m0_used,
|
||||||
|
is_deep_gemm_supported,
|
||||||
|
should_auto_disable_deep_gemm,
|
||||||
|
should_use_deepgemm_for_fp8_linear,
|
||||||
|
)
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .BlockScaledMMLinearKernel import (
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||||
|
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||||
|
act_scale_descriptor = config.activation_quant_key.scale
|
||||||
|
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=False,
|
||||||
|
group_shape=act_scale_descriptor.group_shape,
|
||||||
|
use_ue8m0=self.use_deep_gemm_e8m0,
|
||||||
|
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
|
||||||
|
column_major_scales=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, compute_capability=None):
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return False, "DeepGEMM is only supported on cuda platform"
|
||||||
|
if not is_deep_gemm_supported():
|
||||||
|
return False, "Currently, only Hopper and Blackwell GPUs are supported."
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, config):
|
||||||
|
can_implement_base, reason = super().can_implement(config)
|
||||||
|
if not can_implement_base:
|
||||||
|
return can_implement_base, reason
|
||||||
|
if config.out_dtype != torch.bfloat16:
|
||||||
|
return (False, "Supports only output dtype of bfloat16")
|
||||||
|
|
||||||
|
act_quant_desc = config.activation_quant_key.scale
|
||||||
|
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Supports only dynamic per token group activation "
|
||||||
|
"quantization with group_shape=(1,128).",
|
||||||
|
)
|
||||||
|
model_config = get_current_vllm_config().model_config
|
||||||
|
|
||||||
|
if model_config is None:
|
||||||
|
return False, "Model configuration is required."
|
||||||
|
|
||||||
|
model_type = getattr(model_config.hf_text_config, "model_type", None)
|
||||||
|
if should_auto_disable_deep_gemm(model_type):
|
||||||
|
return False, f"Should not use deepgemm for model {model_type}"
|
||||||
|
|
||||||
|
if not should_use_deepgemm_for_fp8_linear(
|
||||||
|
config.out_dtype, config.weight_shape
|
||||||
|
):
|
||||||
|
return False, "The provided metadata is not supported."
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer):
|
||||||
|
super().process_weights_after_loading(layer)
|
||||||
|
params = self._get_layer_params(layer)
|
||||||
|
assert layer.weight_block_size is not None
|
||||||
|
|
||||||
|
if self.is_deep_gemm_supported:
|
||||||
|
weight_scale_invs = params.weight_scale_inv
|
||||||
|
scale_attr = (
|
||||||
|
params.WEIGHT_SCALE_INV
|
||||||
|
if weight_scale_invs is not None
|
||||||
|
else params.WEIGHT_SCALE
|
||||||
|
)
|
||||||
|
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||||
|
wq=params.weight,
|
||||||
|
ws=weight_scale_invs
|
||||||
|
if weight_scale_invs is not None
|
||||||
|
else params.weight_scale,
|
||||||
|
quant_block_shape=tuple(layer.weight_block_size),
|
||||||
|
use_e8m0=self.use_deep_gemm_e8m0,
|
||||||
|
)
|
||||||
|
replace_parameter(layer, params.WEIGHT, dg_weight)
|
||||||
|
replace_parameter(layer, scale_attr, dg_weight_scale)
|
||||||
|
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
out_dtype = self.config.out_dtype
|
||||||
|
output = torch.empty(
|
||||||
|
(A.shape[0], B.shape[0]),
|
||||||
|
dtype=out_dtype,
|
||||||
|
device=A.device,
|
||||||
|
)
|
||||||
|
torch.ops.vllm.fp8_gemm_nt_op(A, As, B, Bs, output, self.use_deep_gemm_e8m0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _fp8_gemm_nt_op(
|
||||||
|
q_input: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
use_deep_gemm_e8m0: bool,
|
||||||
|
) -> None:
|
||||||
|
fp8_gemm_nt(
|
||||||
|
(q_input, input_scale),
|
||||||
|
(weight, weight_scale),
|
||||||
|
output,
|
||||||
|
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fp8_gemm_nt_op_fake(
|
||||||
|
q_input: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
use_deep_gemm_e8m0: bool,
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
"fp8_gemm_nt_op",
|
||||||
|
_fp8_gemm_nt_op,
|
||||||
|
mutates_args=["output"],
|
||||||
|
fake_impl=_fp8_gemm_nt_op_fake,
|
||||||
|
)
|
||||||
@@ -2,11 +2,32 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
from vllm.utils.flashinfer import (
|
||||||
|
flashinfer_fp8_blockscale_gemm,
|
||||||
|
flashinfer_scaled_fp8_mm,
|
||||||
|
has_flashinfer,
|
||||||
|
is_flashinfer_fp8_blockscale_gemm_supported,
|
||||||
|
should_use_flashinfer_for_blockscale_fp8_gemm,
|
||||||
|
)
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .BlockScaledMMLinearKernel import (
|
||||||
|
Fp8BlockScaledDynamicMMLinearKernel,
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
|
)
|
||||||
|
from .deep_gemm import DeepGemmFp8BlockScaledMMKernel, fp8_gemm_nt
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
@@ -55,3 +76,256 @@ class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
return flashinfer_scaled_fp8_mm(
|
return flashinfer_scaled_fp8_mm(
|
||||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||||
|
# FlashInfer accepts BF16 input and handles FP8 conversion internally.
|
||||||
|
apply_input_quant: ClassVar[bool] = False
|
||||||
|
|
||||||
|
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
can_implement_base, reason = super().can_implement(config)
|
||||||
|
if not can_implement_base:
|
||||||
|
return can_implement_base, reason
|
||||||
|
|
||||||
|
act_quant_desc = config.activation_quant_key.scale
|
||||||
|
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Supports only dynamic per token group activation "
|
||||||
|
"quantization with group_shape=(1,128).",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||||
|
is_flashinfer_fp8_blockscale_gemm_supported(),
|
||||||
|
config.out_dtype,
|
||||||
|
config.input_dtype,
|
||||||
|
config.weight_quant_key.dtype,
|
||||||
|
config.weight_shape,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"The provided metadata is not supported.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, compute_capability=None):
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return False, "only cuda devices are supported."
|
||||||
|
|
||||||
|
if not is_flashinfer_fp8_blockscale_gemm_supported():
|
||||||
|
return False, "FlashInfer block-scale FP8 GEMM is not available."
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# A is BF16 — FlashInfer handles FP8 conversion internally.
|
||||||
|
# As is a placeholder (apply_input_quant=False) and is not used here.
|
||||||
|
return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
|
||||||
|
A, # BF16 input
|
||||||
|
B, # FP8 weight
|
||||||
|
Bs, # Weight scales
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferFp8DeepGEMMDynamicBlockScaledKernel(
|
||||||
|
Fp8BlockScaledDynamicMMLinearKernel
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.
|
||||||
|
|
||||||
|
Dispatches between two kernels based on input batch size:
|
||||||
|
- Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
|
||||||
|
- Large batches (M >= 32): DeepGEMM for peak throughput.
|
||||||
|
|
||||||
|
apply_input_quant is False because FlashInfer accepts BF16 input and
|
||||||
|
handles FP8 conversion internally. The DeepGEMM branch therefore
|
||||||
|
quantises BF16→FP8 inside apply_mm via a closure before dispatching to
|
||||||
|
the DeepGEMM kernel — keeping both branches compatible with the single
|
||||||
|
BF16 tensor operand list passed by torch.cond.
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
|
||||||
|
FlashInferFp8BlockScaledMMKernel
|
||||||
|
)
|
||||||
|
fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
|
||||||
|
DeepGemmFp8BlockScaledMMKernel
|
||||||
|
)
|
||||||
|
apply_input_quant: ClassVar[bool] = False
|
||||||
|
|
||||||
|
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.base: FlashInferFp8BlockScaledMMKernel
|
||||||
|
self.fallback: DeepGemmFp8BlockScaledMMKernel
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||||
|
# DeepGEMM need post-processing; both kernels share the same
|
||||||
|
# parameter tensor layout so processing once is sufficient.
|
||||||
|
self.fallback.process_weights_after_loading(layer)
|
||||||
|
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
group_size = self.weight_group_shape.col
|
||||||
|
use_deep_gemm_e8m0 = self.fallback.use_deep_gemm_e8m0
|
||||||
|
|
||||||
|
return torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm(
|
||||||
|
A, B, Bs, group_size, use_deep_gemm_e8m0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flashinfer_fp8_blockscale_gemm_impl(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return flashinfer_fp8_blockscale_gemm(
|
||||||
|
input=input,
|
||||||
|
weight=weight,
|
||||||
|
weight_scale=weight_scale,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flashinfer_fp8_blockscale_gemm_fake(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Required fake/meta implementation for torch.compile graph tracing.
|
||||||
|
"""
|
||||||
|
return torch.empty(
|
||||||
|
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
"flashinfer_fp8_blockscale_gemm",
|
||||||
|
_flashinfer_fp8_blockscale_gemm_impl,
|
||||||
|
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
use_deep_gemm_e8m0: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
|
||||||
|
|
||||||
|
This function switches between two optimized kernels based on the input batch size:
|
||||||
|
- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
|
||||||
|
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.
|
||||||
|
|
||||||
|
The conditional logic must use torch.cond() instead of a simple if-else statement
|
||||||
|
to maintain compatibility with torch.compile graph compilation.
|
||||||
|
|
||||||
|
This batch-size-dependent selection is essential for maintaining model accuracy.
|
||||||
|
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
|
||||||
|
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
|
||||||
|
drop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input tensor of shape (batch_size, input_dim) in FP8 format
|
||||||
|
weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
|
||||||
|
weight_scale: Scale factors for weight quantization (per-group)
|
||||||
|
group_size: Quantization group size for the weight tensor
|
||||||
|
use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor of shape (batch_size, output_dim) in bfloat16 format
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run_flashinfer_deepgemm_swapAB(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return flashinfer_fp8_blockscale_gemm(
|
||||||
|
input=input,
|
||||||
|
weight=weight,
|
||||||
|
weight_scale=weight_scale,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_deepgemm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
q_input, input_scale = per_token_group_quant_fp8(
|
||||||
|
input,
|
||||||
|
group_size=group_size,
|
||||||
|
column_major_scales=True,
|
||||||
|
use_ue8m0=use_deep_gemm_e8m0,
|
||||||
|
)
|
||||||
|
output = torch.empty(
|
||||||
|
(q_input.shape[0], weight.shape[0]),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=q_input.device,
|
||||||
|
)
|
||||||
|
fp8_gemm_nt(
|
||||||
|
(q_input, input_scale),
|
||||||
|
(weight, weight_scale),
|
||||||
|
output,
|
||||||
|
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
if envs.VLLM_BATCH_INVARIANT:
|
||||||
|
return run_deepgemm(input, weight, weight_scale)
|
||||||
|
|
||||||
|
condition = input.shape[0] < 32
|
||||||
|
|
||||||
|
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
|
||||||
|
# Python conditionals. torch.cond() explicitly registers both code paths in the
|
||||||
|
# computation graph, allowing torch.compile to capture both branches.
|
||||||
|
# without torch.cond, the M < 32 condition won't be able to be captured by torch
|
||||||
|
# compile
|
||||||
|
return torch.cond(
|
||||||
|
condition,
|
||||||
|
run_flashinfer_deepgemm_swapAB,
|
||||||
|
run_deepgemm,
|
||||||
|
(input, weight, weight_scale),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
use_deep_gemm_e8m0: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Required fake/meta implementation for torch.compile graph tracing.
|
||||||
|
"""
|
||||||
|
return torch.empty(
|
||||||
|
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
"dynamic_flashinfer_deepgemm_blockscale_gemm",
|
||||||
|
_dynamic_flashinfer_deepgemm_blockscale_gemm_impl,
|
||||||
|
fake_impl=_dynamic_flashinfer_deepgemm_blockscale_gemm_fake,
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,7 +13,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .BlockScaledMMLinearKernel import (
|
||||||
|
Fp8BlockScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
Int8ScaledMMLinearLayerConfig,
|
Int8ScaledMMLinearLayerConfig,
|
||||||
@@ -150,3 +154,67 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
|||||||
out -= (x_s * w_s_row * azp_adj).to(x.dtype)
|
out -= (x_s * w_s_row * azp_adj).to(x.dtype)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class TritonFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, compute_capability=None):
|
||||||
|
if not current_platform.is_cuda_alike():
|
||||||
|
return False, "only cuda like devices are supported."
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def apply_block_scaled_mm(
|
||||||
|
self,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
list(self.weight_group_shape),
|
||||||
|
self.config.out_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO we should be able to change the type of block_size to GroupShape
|
||||||
|
# after we resolve GroupShape compilation issue
|
||||||
|
# https://github.com/vllm-project/vllm/issues/25270
|
||||||
|
def _w8a8_triton_block_scaled_mm_func(
|
||||||
|
qx: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
w8a8_triton_block_scaled_mm,
|
||||||
|
)
|
||||||
|
|
||||||
|
return w8a8_triton_block_scaled_mm(
|
||||||
|
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _w8a8_triton_block_scaled_mm_fake(
|
||||||
|
qx: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty(
|
||||||
|
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
"w8a8_triton_block_scaled_mm_func",
|
||||||
|
_w8a8_triton_block_scaled_mm_func,
|
||||||
|
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
|
|||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
init_fp8_linear_kernel,
|
init_fp8_linear_kernel,
|
||||||
@@ -16,18 +17,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
|||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
|
||||||
create_fp8_input_scale,
|
create_fp8_input_scale,
|
||||||
create_fp8_scale_parameter,
|
create_fp8_scale_parameter,
|
||||||
create_fp8_weight_parameter,
|
create_fp8_weight_parameter,
|
||||||
maybe_post_process_fp8_weight_block,
|
|
||||||
process_fp8_weight_block_strategy,
|
|
||||||
process_fp8_weight_channel_strategy,
|
process_fp8_weight_channel_strategy,
|
||||||
process_fp8_weight_tensor_strategy,
|
process_fp8_weight_tensor_strategy,
|
||||||
validate_fp8_block_shape,
|
validate_fp8_block_shape,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
|
create_fp8_quant_key,
|
||||||
kFp8DynamicTokenSym,
|
kFp8DynamicTokenSym,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
kFp8StaticTokenSym,
|
kFp8StaticTokenSym,
|
||||||
@@ -67,6 +66,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
self.weight_quant = weight_quant
|
self.weight_quant = weight_quant
|
||||||
self.strategy = weight_quant.strategy
|
self.strategy = weight_quant.strategy
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
self.is_static_input_scheme = is_static_input_scheme
|
self.is_static_input_scheme = is_static_input_scheme
|
||||||
self.weight_block_size = self.weight_quant.block_structure
|
self.weight_block_size = self.weight_quant.block_structure
|
||||||
|
|
||||||
@@ -75,21 +75,18 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||||
assert not self.is_static_input_scheme
|
assert not self.is_static_input_scheme
|
||||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
|
||||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
act_quant_group_shape=self.act_q_group_shape,
|
static=True, group_shape=GroupShape(*self.weight_block_size)
|
||||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
)
|
||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
self.activation_quant_key = create_fp8_quant_key(
|
||||||
|
static=False, group_shape=self.act_q_group_shape
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
|
self.activation_quant_key = activation_quant_key_mapping[
|
||||||
weight_quant_key = weight_quant_key_mapping[self.strategy]
|
is_static_input_scheme
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
]
|
||||||
activation_quant_key=activation_quant_key,
|
self.weight_quant_key = weight_quant_key_mapping[self.strategy]
|
||||||
weight_quant_key=weight_quant_key,
|
|
||||||
out_dtype=self.out_dtype,
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@@ -146,6 +143,15 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||||
layer.register_parameter("input_scale", input_scale)
|
layer.register_parameter("input_scale", input_scale)
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=self.activation_quant_key,
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer) -> None:
|
def process_weights_after_loading(self, layer) -> None:
|
||||||
if self.strategy == QuantizationStrategy.TENSOR:
|
if self.strategy == QuantizationStrategy.TENSOR:
|
||||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||||
@@ -163,10 +169,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
|
|
||||||
elif self.strategy == QuantizationStrategy.BLOCK:
|
elif self.strategy == QuantizationStrategy.BLOCK:
|
||||||
assert self.is_static_input_scheme is False
|
assert self.is_static_input_scheme is False
|
||||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
self.fp8_linear.process_weights_after_loading(layer)
|
||||||
layer.weight, layer.weight_scale
|
|
||||||
)
|
layer.input_scale = None
|
||||||
input_scale = None
|
# fp8_linear.process_weights_after_loading applies the post process
|
||||||
|
# and reassigns the weight and weight_scale buffers to layer attributes.
|
||||||
|
return
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -185,8 +193,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
||||||
else:
|
else:
|
||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
if self.strategy == QuantizationStrategy.BLOCK:
|
|
||||||
maybe_post_process_fp8_weight_block(layer)
|
|
||||||
|
|
||||||
if hasattr(self, "fp8_linear"):
|
if hasattr(self, "fp8_linear"):
|
||||||
self.fp8_linear.process_weights_after_loading(layer)
|
self.fp8_linear.process_weights_after_loading(layer)
|
||||||
@@ -197,13 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.weight_block_size is not None:
|
|
||||||
return self.w8a8_block_fp8_linear.apply(
|
|
||||||
input=x,
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale,
|
|
||||||
input_scale=layer.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
init_fp8_linear_kernel,
|
init_fp8_linear_kernel,
|
||||||
@@ -93,12 +94,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
activation_quant_key=kFp8DynamicTokenSym,
|
|
||||||
weight_quant_key=kFp8StaticTokenSym,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -149,6 +145,15 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
layer.input_scale_ub = input_scale_ub
|
layer.input_scale_ub = input_scale_ub
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=kFp8DynamicTokenSym,
|
||||||
|
weight_quant_key=kFp8StaticTokenSym,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
# required by torch.compile
|
# required by torch.compile
|
||||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
@@ -45,13 +45,10 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
|
||||||
create_fp8_input_scale,
|
create_fp8_input_scale,
|
||||||
create_fp8_scale_parameter,
|
create_fp8_scale_parameter,
|
||||||
create_fp8_weight_parameter,
|
create_fp8_weight_parameter,
|
||||||
maybe_post_process_fp8_weight_block,
|
|
||||||
process_fp8_input_tensor_strategy_moe,
|
process_fp8_input_tensor_strategy_moe,
|
||||||
process_fp8_weight_block_strategy,
|
|
||||||
process_fp8_weight_tensor_strategy,
|
process_fp8_weight_tensor_strategy,
|
||||||
process_fp8_weight_tensor_strategy_moe,
|
process_fp8_weight_tensor_strategy_moe,
|
||||||
validate_fp8_block_shape,
|
validate_fp8_block_shape,
|
||||||
@@ -61,6 +58,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
|
create_fp8_quant_key,
|
||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
kFp8Dynamic128Sym,
|
kFp8Dynamic128Sym,
|
||||||
kFp8DynamicTensorSym,
|
kFp8DynamicTensorSym,
|
||||||
@@ -273,12 +271,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
|
|
||||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||||
# kernel for fast weight-only FP8 quantization
|
# kernel for fast weight-only FP8 quantization
|
||||||
self.marlin_input_dtype = None
|
self.marlin_input_dtype = None
|
||||||
|
self.use_marlin = False
|
||||||
|
|
||||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
|
||||||
if self.quant_config.use_deep_gemm is not None:
|
if self.quant_config.use_deep_gemm is not None:
|
||||||
self.use_deep_gemm = self.quant_config.use_deep_gemm
|
self.use_deep_gemm = self.quant_config.use_deep_gemm
|
||||||
else:
|
else:
|
||||||
@@ -288,37 +287,26 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
self.block_quant = self.weight_block_size is not None
|
self.block_quant = self.weight_block_size is not None
|
||||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||||
|
|
||||||
# Use per-token quantization for better perf if dynamic and cutlass
|
|
||||||
if self.act_q_static:
|
|
||||||
activation_quant_key = kFp8StaticTensorSym
|
|
||||||
elif cutlass_fp8_supported():
|
|
||||||
activation_quant_key = kFp8DynamicTokenSym
|
|
||||||
else:
|
|
||||||
activation_quant_key = kFp8DynamicTensorSym
|
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
weight_quant_key = kFp8Static128BlockSym
|
|
||||||
else:
|
|
||||||
weight_quant_key = kFp8StaticTensorSym
|
|
||||||
|
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
|
||||||
activation_quant_key=activation_quant_key,
|
|
||||||
weight_quant_key=weight_quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
|
||||||
|
|
||||||
if self.block_quant and not self.use_marlin:
|
|
||||||
assert not self.act_q_static
|
assert not self.act_q_static
|
||||||
assert self.weight_block_size is not None
|
assert self.weight_block_size is not None
|
||||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
|
||||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
self.activation_quant_key = create_fp8_quant_key(
|
||||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
static=self.act_q_static,
|
||||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
|
||||||
use_deep_gemm=self.use_deep_gemm,
|
|
||||||
)
|
)
|
||||||
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
|
static=True, group_shape=GroupShape(*self.weight_block_size)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.weight_quant_key = kFp8StaticTensorSym
|
||||||
|
# Use per-token quantization for better perf if dynamic and cutlass
|
||||||
|
if self.act_q_static:
|
||||||
|
self.activation_quant_key = kFp8StaticTensorSym
|
||||||
|
elif cutlass_fp8_supported():
|
||||||
|
self.activation_quant_key = kFp8DynamicTokenSym
|
||||||
|
else:
|
||||||
|
self.activation_quant_key = kFp8DynamicTensorSym
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -384,6 +372,17 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||||
layer.register_parameter("input_scale", scale)
|
layer.register_parameter("input_scale", scale)
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=self.activation_quant_key,
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
|
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
|
||||||
@@ -398,13 +397,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert not self.act_q_static
|
assert not self.act_q_static
|
||||||
|
|
||||||
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
self.fp8_linear.process_weights_after_loading(layer)
|
||||||
layer.weight, layer.weight_scale_inv
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update layer with new values
|
|
||||||
replace_parameter(layer, "weight", weight.data)
|
|
||||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
|
||||||
|
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
else:
|
else:
|
||||||
@@ -435,9 +428,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
else:
|
else:
|
||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
|
|
||||||
if self.block_quant and self.use_deep_gemm:
|
|
||||||
maybe_post_process_fp8_weight_block(layer)
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -449,12 +439,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
if envs.VLLM_BATCH_INVARIANT:
|
if envs.VLLM_BATCH_INVARIANT:
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.weight_block_size is not None
|
assert self.weight_block_size is not None
|
||||||
return self.w8a8_block_fp8_linear.apply(
|
return self.fp8_linear.apply_weights(
|
||||||
input=x,
|
layer,
|
||||||
weight=layer.weight,
|
x,
|
||||||
weight_scale=layer.weight_scale_inv,
|
bias,
|
||||||
input_scale=layer.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# per-tensor/channel: dequant to BF16 and run GEMM
|
# per-tensor/channel: dequant to BF16 and run GEMM
|
||||||
@@ -483,17 +471,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||||
|
|
||||||
if self.block_quant:
|
|
||||||
assert self.weight_block_size is not None
|
|
||||||
|
|
||||||
return self.w8a8_block_fp8_linear.apply(
|
|
||||||
input=x,
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale_inv,
|
|
||||||
input_scale=layer.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||||
|
|
||||||
|
|
||||||
@@ -538,6 +515,24 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
|||||||
|
|
||||||
initialize_online_processing(layer)
|
initialize_online_processing(layer)
|
||||||
|
|
||||||
|
# TODO: remove this check once the following RFC is resolved.
|
||||||
|
# https://github.com/vllm-project/vllm/issues/33314
|
||||||
|
# This check is required because Mxfp8OnlineLinearMethod inherits from
|
||||||
|
# Fp8OnlineLinearMethod but only calls super().create_weights(), so we must
|
||||||
|
# skip the fp8_linear kernel creation.
|
||||||
|
if hasattr(self, "mxfp8_linear"):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=self.activation_quant_key,
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||||
@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
|||||||
swap_w13_to_w31,
|
swap_w13_to_w31,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
|
||||||
process_fp8_input_tensor_strategy_moe,
|
process_fp8_input_tensor_strategy_moe,
|
||||||
process_fp8_weight_tensor_strategy_moe,
|
process_fp8_weight_tensor_strategy_moe,
|
||||||
)
|
)
|
||||||
@@ -78,6 +78,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
|
create_fp8_quant_key,
|
||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
kFp8DynamicTokenSym,
|
kFp8DynamicTokenSym,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
@@ -86,7 +87,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
kNvfp4Static,
|
kNvfp4Static,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
cutlass_block_fp8_supported,
|
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
@@ -450,12 +450,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.out_dtype = torch.get_default_dtype()
|
||||||
activation_quant_key=kFp8StaticTensorSym,
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
weight_quant_key=kFp8StaticTensorSym,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -505,6 +501,15 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
|||||||
scale[:] = torch.finfo(torch.float32).min
|
scale[:] = torch.finfo(torch.float32).min
|
||||||
layer.register_parameter("input_scale", scale)
|
layer.register_parameter("input_scale", scale)
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=kFp8StaticTensorSym,
|
||||||
|
weight_quant_key=kFp8StaticTensorSym,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
max_w_scale = layer.weight_scale.max()
|
max_w_scale = layer.weight_scale.max()
|
||||||
@@ -536,12 +541,8 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.out_dtype = torch.get_default_dtype()
|
||||||
activation_quant_key=kFp8DynamicTokenSym,
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
weight_quant_key=kFp8StaticTokenSym,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -587,6 +588,15 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
|||||||
weight_scale[:] = torch.finfo(torch.float32).min
|
weight_scale[:] = torch.finfo(torch.float32).min
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=kFp8DynamicTokenSym,
|
||||||
|
weight_quant_key=kFp8StaticTokenSym,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||||
@@ -616,12 +626,16 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
block_n, block_k = self._WEIGHT_BLOCK_SIZE
|
block_n, block_k = self._WEIGHT_BLOCK_SIZE
|
||||||
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
|
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
|
||||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
|
||||||
weight_group_shape=GroupShape(block_n, block_k),
|
self.activation_quant_key = create_fp8_quant_key(
|
||||||
act_quant_group_shape=GroupShape(1, block_k),
|
static=False, group_shape=GroupShape(1, block_k)
|
||||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
|
||||||
use_aiter_and_is_supported=False,
|
|
||||||
)
|
)
|
||||||
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
|
static=True, group_shape=GroupShape(block_n, block_k)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -688,8 +702,17 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
|||||||
weight_scale[:] = torch.finfo(torch.float32).min
|
weight_scale[:] = torch.finfo(torch.float32).min
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
|
self.w8a8_block_fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=self.activation_quant_key,
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
|
# Keep weight in [out, in] layout for Fp8BlockScaledMMLinearKernel.
|
||||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||||
|
|
||||||
scale = layer.weight_scale
|
scale = layer.weight_scale
|
||||||
@@ -713,13 +736,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.w8a8_block_fp8_linear.apply(
|
return self.w8a8_block_fp8_linear.apply_weights(layer, x, bias)
|
||||||
input=x,
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale,
|
|
||||||
input_scale=None,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
@@ -28,13 +28,9 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
|||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
||||||
W8A8BlockFp8LinearOp,
|
|
||||||
maybe_post_process_fp8_weight_block,
|
|
||||||
process_fp8_weight_block_strategy,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
|
create_fp8_quant_key,
|
||||||
kFp8Dynamic128Sym,
|
kFp8Dynamic128Sym,
|
||||||
kFp8DynamicTensorSym,
|
kFp8DynamicTensorSym,
|
||||||
kFp8DynamicTokenSym,
|
kFp8DynamicTokenSym,
|
||||||
@@ -42,7 +38,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
cutlass_block_fp8_supported,
|
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.reload.layerwise import (
|
from vllm.model_executor.model_loader.reload.layerwise import (
|
||||||
@@ -51,7 +46,7 @@ from vllm.model_executor.model_loader.reload.layerwise import (
|
|||||||
from vllm.model_executor.parameter import ModelWeightParameter
|
from vllm.model_executor.parameter import ModelWeightParameter
|
||||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
|
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Online FP8 Linear Methods
|
# Online FP8 Linear Methods
|
||||||
@@ -64,6 +59,10 @@ class _Fp8OnlineLinearBase(LinearMethodBase):
|
|||||||
|
|
||||||
uses_meta_device: bool = True
|
uses_meta_device: bool = True
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -103,18 +102,41 @@ class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
|
|||||||
Loads fp16/bf16 weights and quantizes them per-tensor during loading."""
|
Loads fp16/bf16 weights and quantizes them per-tensor during loading."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.out_dtype = torch.get_default_dtype()
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight_quant_key = kFp8StaticTensorSym
|
||||||
# Use per-token quantization for better perf if dynamic and cutlass
|
# Use per-token quantization for better perf if dynamic and cutlass
|
||||||
if cutlass_fp8_supported():
|
if cutlass_fp8_supported():
|
||||||
activation_quant_key = kFp8DynamicTokenSym
|
self.activation_quant_key = kFp8DynamicTokenSym
|
||||||
else:
|
else:
|
||||||
activation_quant_key = kFp8DynamicTensorSym
|
self.activation_quant_key = kFp8DynamicTensorSym
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: list[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
super().create_weights(
|
||||||
|
layer,
|
||||||
|
input_size_per_partition,
|
||||||
|
output_partition_sizes,
|
||||||
|
input_size,
|
||||||
|
output_size,
|
||||||
|
params_dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
activation_quant_key=activation_quant_key,
|
activation_quant_key=self.activation_quant_key,
|
||||||
weight_quant_key=kFp8StaticTensorSym,
|
weight_quant_key=self.weight_quant_key,
|
||||||
out_dtype=torch.get_default_dtype(),
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__,
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -166,19 +188,14 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
|||||||
Loads fp16/bf16 weights and quantizes them per-block during loading."""
|
Loads fp16/bf16 weights and quantizes them per-block during loading."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.out_dtype = torch.get_default_dtype()
|
super().__init__()
|
||||||
self.weight_block_size = [128, 128]
|
self.weight_block_size = [128, 128]
|
||||||
|
self.activation_quant_key = create_fp8_quant_key(
|
||||||
self.use_deep_gemm = is_deep_gemm_supported()
|
static=False,
|
||||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
)
|
||||||
|
self.weight_quant_key = create_fp8_quant_key(
|
||||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
static=True, group_shape=GroupShape(*self.weight_block_size)
|
||||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
|
||||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
|
||||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
|
||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
|
||||||
use_deep_gemm=self.use_deep_gemm,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -202,6 +219,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
|||||||
)
|
)
|
||||||
layer.weight_block_size = self.weight_block_size
|
layer.weight_block_size = self.weight_block_size
|
||||||
|
|
||||||
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=self.activation_quant_key,
|
||||||
|
weight_quant_key=self.weight_quant_key,
|
||||||
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
return
|
return
|
||||||
@@ -213,14 +239,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
|||||||
layer.weight, block_size=block_size, use_ue8m0=False
|
layer.weight, block_size=block_size, use_ue8m0=False
|
||||||
)
|
)
|
||||||
|
|
||||||
qweight, weight_scale_inv = process_fp8_weight_block_strategy(
|
|
||||||
qweight, weight_scale_inv
|
|
||||||
)
|
|
||||||
|
|
||||||
replace_parameter(layer, "weight", qweight.data)
|
replace_parameter(layer, "weight", qweight.data)
|
||||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||||
|
|
||||||
maybe_post_process_fp8_weight_block(layer)
|
self.fp8_linear.process_weights_after_loading(layer)
|
||||||
|
|
||||||
# Prevent duplicate processing (e.g., during weight reload)
|
# Prevent duplicate processing (e.g., during weight reload)
|
||||||
layer._already_called_process_weights_after_loading = True
|
layer._already_called_process_weights_after_loading = True
|
||||||
@@ -234,12 +256,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
|||||||
assert self.weight_block_size is not None
|
assert self.weight_block_size is not None
|
||||||
|
|
||||||
# Note: batch invariance already handled in the function below
|
# Note: batch invariance already handled in the function below
|
||||||
return self.w8a8_block_fp8_linear.apply(
|
return self.fp8_linear.apply_weights(
|
||||||
input=x,
|
layer,
|
||||||
weight=layer.weight,
|
x,
|
||||||
weight_scale=layer.weight_scale_inv,
|
bias,
|
||||||
input_scale=layer.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, cast
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import (
|
||||||
init_fp8_linear_kernel,
|
init_fp8_linear_kernel,
|
||||||
@@ -57,6 +58,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
|
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
|
||||||
)
|
)
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@@ -175,7 +177,9 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
activation_quant_key=self.activation_quant_key,
|
activation_quant_key=self.activation_quant_key,
|
||||||
weight_quant_key=self.weight_quant_key,
|
weight_quant_key=self.weight_quant_key,
|
||||||
out_dtype=torch.get_default_dtype(),
|
weight_shape=layer.weight.shape,
|
||||||
|
input_dtype=self.input_dtype,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__,
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,15 +12,11 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
|
||||||
get_fp8_min_max,
|
get_fp8_min_max,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
)
|
)
|
||||||
@@ -29,22 +25,14 @@ from vllm.model_executor.parameter import (
|
|||||||
ChannelQuantScaleParameter,
|
ChannelQuantScaleParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.deep_gemm import (
|
from vllm.utils.deep_gemm import (
|
||||||
fp8_gemm_nt,
|
|
||||||
get_tma_aligned_size,
|
get_tma_aligned_size,
|
||||||
is_deep_gemm_e8m0_used,
|
is_deep_gemm_e8m0_used,
|
||||||
is_deep_gemm_supported,
|
|
||||||
should_use_deepgemm_for_fp8_linear,
|
|
||||||
transform_sf_into_required_layout,
|
transform_sf_into_required_layout,
|
||||||
)
|
)
|
||||||
from vllm.utils.flashinfer import (
|
|
||||||
flashinfer_fp8_blockscale_gemm,
|
|
||||||
is_flashinfer_fp8_blockscale_gemm_supported,
|
|
||||||
should_use_flashinfer_for_blockscale_fp8_gemm,
|
|
||||||
)
|
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -56,153 +44,6 @@ def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
|
|||||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||||
|
|
||||||
|
|
||||||
# We need to pass in the is_hopper flag as argument because the function
|
|
||||||
# current_platform.is_device_capability() is not supported by Torch compiler.
|
|
||||||
def cutlass_scaled_mm(
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
As: torch.Tensor,
|
|
||||||
Bs: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype = torch.float16,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return ops.cutlass_scaled_mm(
|
|
||||||
A,
|
|
||||||
B.T,
|
|
||||||
out_dtype=output_dtype,
|
|
||||||
scale_a=As,
|
|
||||||
scale_b=Bs.T,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO we should be able to change the type of block_size to GroupShape
|
|
||||||
# after we resolve GroupShape compilation issue
|
|
||||||
# https://github.com/vllm-project/vllm/issues/25270
|
|
||||||
def _w8a8_triton_block_scaled_mm_func(
|
|
||||||
qx: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
x_scale: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return w8a8_triton_block_scaled_mm(
|
|
||||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _w8a8_triton_block_scaled_mm_fake(
|
|
||||||
qx: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
x_scale: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.empty(
|
|
||||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
"w8a8_triton_block_scaled_mm_func",
|
|
||||||
_w8a8_triton_block_scaled_mm_func,
|
|
||||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _padded_cutlass(
|
|
||||||
qx: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
x_scale: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
pad_multiple = 4
|
|
||||||
dim = qx.shape[0]
|
|
||||||
padded = (
|
|
||||||
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
|
|
||||||
)
|
|
||||||
|
|
||||||
has_pad = padded > dim
|
|
||||||
|
|
||||||
if has_pad:
|
|
||||||
padded_shape = [padded, *qx.shape[1:]]
|
|
||||||
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
|
|
||||||
padded_qx[0 : qx.shape[0], ...].copy_(qx)
|
|
||||||
|
|
||||||
padded_x_scale_shape = [*x_scale.shape[1:], padded]
|
|
||||||
padded_x_scale = torch.ones(
|
|
||||||
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
|
|
||||||
).permute(-1, -2)
|
|
||||||
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
|
||||||
|
|
||||||
output = cutlass_scaled_mm(
|
|
||||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
|
|
||||||
)
|
|
||||||
return output[0 : qx.shape[0], ...]
|
|
||||||
else:
|
|
||||||
return cutlass_scaled_mm(
|
|
||||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _padded_cutlass_fake(
|
|
||||||
qx: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
x_scale: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.empty(
|
|
||||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
"padded_cutlass",
|
|
||||||
_padded_cutlass,
|
|
||||||
fake_impl=_padded_cutlass_fake,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _fp8_gemm_nt_op(
|
|
||||||
q_input: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
output: torch.Tensor,
|
|
||||||
use_deep_gemm_e8m0: bool,
|
|
||||||
) -> None:
|
|
||||||
fp8_gemm_nt(
|
|
||||||
(q_input, input_scale),
|
|
||||||
(weight, weight_scale),
|
|
||||||
output,
|
|
||||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _fp8_gemm_nt_op_fake(
|
|
||||||
q_input: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
output: torch.Tensor,
|
|
||||||
use_deep_gemm_e8m0: bool,
|
|
||||||
) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
"fp8_gemm_nt_op",
|
|
||||||
_fp8_gemm_nt_op,
|
|
||||||
mutates_args=["output"],
|
|
||||||
fake_impl=_fp8_gemm_nt_op_fake,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _triton_per_token_group_quant_fp8_impl(
|
def _triton_per_token_group_quant_fp8_impl(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
@@ -236,362 +77,6 @@ direct_register_custom_op(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _flashinfer_fp8_blockscale_gemm_impl(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
use_deep_gemm_e8m0: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
|
|
||||||
|
|
||||||
This function switches between two optimized kernels based on the input batch size:
|
|
||||||
- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
|
|
||||||
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.
|
|
||||||
|
|
||||||
The conditional logic must use torch.cond() instead of a simple if-else statement
|
|
||||||
to maintain compatibility with torch.compile graph compilation.
|
|
||||||
|
|
||||||
This batch-size-dependent selection is essential for maintaining model accuracy.
|
|
||||||
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
|
|
||||||
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
|
|
||||||
drop.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input: Input tensor of shape (batch_size, input_dim) in FP8 format
|
|
||||||
weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
|
|
||||||
weight_scale: Scale factors for weight quantization (per-group)
|
|
||||||
group_size: Quantization group size for the weight tensor
|
|
||||||
use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Output tensor of shape (batch_size, output_dim) in bfloat16 format
|
|
||||||
"""
|
|
||||||
|
|
||||||
def run_flashinfer_deepgemm_swapAB(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return flashinfer_fp8_blockscale_gemm(
|
|
||||||
input=input,
|
|
||||||
weight=weight,
|
|
||||||
weight_scale=weight_scale,
|
|
||||||
out_dtype=torch.bfloat16,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_deepgemm(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
q_input, input_scale = per_token_group_quant_fp8(
|
|
||||||
input,
|
|
||||||
group_size=group_size,
|
|
||||||
column_major_scales=True,
|
|
||||||
use_ue8m0=use_deep_gemm_e8m0,
|
|
||||||
)
|
|
||||||
output = torch.empty(
|
|
||||||
(q_input.shape[0], weight.shape[0]),
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
device=q_input.device,
|
|
||||||
)
|
|
||||||
fp8_gemm_nt(
|
|
||||||
(q_input, input_scale),
|
|
||||||
(weight, weight_scale),
|
|
||||||
output,
|
|
||||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
if envs.VLLM_BATCH_INVARIANT:
|
|
||||||
return run_deepgemm(input, weight, weight_scale)
|
|
||||||
|
|
||||||
condition = input.shape[0] < 32
|
|
||||||
|
|
||||||
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
|
|
||||||
# Python conditionals. torch.cond() explicitly registers both code paths in the
|
|
||||||
# computation graph, allowing torch.compile to capture both branches.
|
|
||||||
# without torch.cond, the M < 32 condition won't be able to be captured by torch
|
|
||||||
# compile
|
|
||||||
return torch.cond(
|
|
||||||
condition,
|
|
||||||
run_flashinfer_deepgemm_swapAB,
|
|
||||||
run_deepgemm,
|
|
||||||
(input, weight, weight_scale),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _flashinfer_fp8_blockscale_gemm_fake(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
use_deep_gemm_e8m0: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Required fake/meta implementation for torch.compile graph tracing.
|
|
||||||
"""
|
|
||||||
return torch.empty(
|
|
||||||
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
"flashinfer_fp8_blockscale_gemm",
|
|
||||||
_flashinfer_fp8_blockscale_gemm_impl,
|
|
||||||
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO fix ROCm->Triton custom path:
|
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
|
||||||
class W8A8BlockFp8LinearOp:
|
|
||||||
"""
|
|
||||||
This class executes a Blocked FP8 linear layer using cutlass if supported
|
|
||||||
and torch.scaled_mm otherwise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
weight_group_shape: GroupShape,
|
|
||||||
act_quant_group_shape: GroupShape,
|
|
||||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
|
||||||
use_aiter_and_is_supported: bool = False,
|
|
||||||
use_deep_gemm: bool | None = None,
|
|
||||||
):
|
|
||||||
self.weight_group_shape = weight_group_shape
|
|
||||||
self.act_quant_group_shape = act_quant_group_shape
|
|
||||||
if use_deep_gemm is not None:
|
|
||||||
self.is_deep_gemm_supported = use_deep_gemm
|
|
||||||
else:
|
|
||||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
|
||||||
self.is_hopper = current_platform.is_device_capability(90)
|
|
||||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
|
||||||
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
|
|
||||||
|
|
||||||
# Get the correct blockscale mul and input quant operations.
|
|
||||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
|
||||||
# to use deepgemm because we don't know the shape of weights (and
|
|
||||||
# whether deepgemm supports it) at the init time.
|
|
||||||
self.w8a8_blockscale_op, self.input_quant_op = (
|
|
||||||
self._dispatch_w8a8_blockscale_op(
|
|
||||||
cutlass_block_fp8_supported, use_aiter_and_is_supported
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.deepgemm_input_quant_op = (
|
|
||||||
QuantFP8(
|
|
||||||
False,
|
|
||||||
self.act_quant_group_shape,
|
|
||||||
column_major_scales=True,
|
|
||||||
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
|
|
||||||
use_ue8m0=self.use_deep_gemm_e8m0,
|
|
||||||
)
|
|
||||||
if self.is_deep_gemm_supported
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor | None = None,
|
|
||||||
bias: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert input_scale is None
|
|
||||||
# View input as 2D matrix for fp8 methods
|
|
||||||
input_2d = input.view(-1, input.shape[-1])
|
|
||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
|
||||||
output_dtype = input.dtype
|
|
||||||
|
|
||||||
if should_use_flashinfer_for_blockscale_fp8_gemm(
|
|
||||||
self.is_flashinfer_supported, output_dtype, input_2d, weight
|
|
||||||
) and should_use_deepgemm_for_fp8_linear(
|
|
||||||
output_dtype, weight, self.is_deep_gemm_supported
|
|
||||||
):
|
|
||||||
output = self._run_flashinfer(input_2d, weight, weight_scale)
|
|
||||||
|
|
||||||
elif should_use_deepgemm_for_fp8_linear(
|
|
||||||
output_dtype, weight, self.is_deep_gemm_supported
|
|
||||||
):
|
|
||||||
output = self._run_deepgemm(input_2d, weight, weight_scale)
|
|
||||||
else:
|
|
||||||
output = self.w8a8_blockscale_op(
|
|
||||||
input_2d, weight, weight_scale, input_scale
|
|
||||||
)
|
|
||||||
|
|
||||||
if bias is not None:
|
|
||||||
output = output + bias
|
|
||||||
return output.to(dtype=input.dtype).view(*output_shape)
|
|
||||||
|
|
||||||
def _run_deepgemm(
|
|
||||||
self,
|
|
||||||
input_2d: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert self.deepgemm_input_quant_op is not None
|
|
||||||
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
|
|
||||||
output = torch.empty(
|
|
||||||
(q_input.shape[0], weight.shape[0]),
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
device=q_input.device,
|
|
||||||
)
|
|
||||||
torch.ops.vllm.fp8_gemm_nt_op(
|
|
||||||
q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def _run_cutlass(
|
|
||||||
self,
|
|
||||||
input_2d: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert input_scale is None
|
|
||||||
assert self.input_quant_op is not None
|
|
||||||
q_input, input_scale = self.input_quant_op(input_2d)
|
|
||||||
if self.is_hopper:
|
|
||||||
return torch.ops.vllm.padded_cutlass(
|
|
||||||
q_input,
|
|
||||||
weight,
|
|
||||||
input_scale,
|
|
||||||
weight_scale,
|
|
||||||
list(self.weight_group_shape),
|
|
||||||
input_2d.dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return cutlass_scaled_mm(
|
|
||||||
q_input,
|
|
||||||
weight,
|
|
||||||
input_scale,
|
|
||||||
weight_scale,
|
|
||||||
list(self.weight_group_shape),
|
|
||||||
input_2d.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_aiter(
|
|
||||||
self,
|
|
||||||
input_2d: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
|
||||||
|
|
||||||
n, k = weight.shape
|
|
||||||
|
|
||||||
use_triton = (
|
|
||||||
not current_platform.is_fp8_fnuz()
|
|
||||||
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_triton:
|
|
||||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
|
|
||||||
else:
|
|
||||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
|
|
||||||
|
|
||||||
if input_scale is not None:
|
|
||||||
q_input = input_2d
|
|
||||||
else:
|
|
||||||
q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
|
|
||||||
|
|
||||||
return gemm_a8w8_blockscale_op(
|
|
||||||
q_input,
|
|
||||||
weight,
|
|
||||||
input_scale,
|
|
||||||
weight_scale,
|
|
||||||
list(self.weight_group_shape),
|
|
||||||
output_dtype=input_2d.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_triton(
|
|
||||||
self,
|
|
||||||
input_2d: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert input_scale is None
|
|
||||||
assert self.input_quant_op is not None
|
|
||||||
q_input, input_scale = self.input_quant_op(input_2d)
|
|
||||||
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
|
|
||||||
q_input,
|
|
||||||
weight,
|
|
||||||
input_scale,
|
|
||||||
weight_scale,
|
|
||||||
list(self.weight_group_shape),
|
|
||||||
input_2d.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_flashinfer(
|
|
||||||
self,
|
|
||||||
input_2d: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
weight_scale: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Run FlashInfer FP8 block-scale GEMM.
|
|
||||||
|
|
||||||
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
|
|
||||||
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
|
|
||||||
"""
|
|
||||||
# Now call FlashInfer with BF16 input + FP8 weight, input will be
|
|
||||||
# quantized with FlashInfer kernel (W8A8)
|
|
||||||
output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
|
|
||||||
input=input_2d, # BF16 input
|
|
||||||
weight=weight, # FP8 weight
|
|
||||||
weight_scale=weight_scale, # Weight scales
|
|
||||||
group_size=self.act_quant_group_shape.col,
|
|
||||||
use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def _dispatch_w8a8_blockscale_op(
|
|
||||||
self,
|
|
||||||
use_cutlass: bool,
|
|
||||||
use_aiter_and_is_supported: bool,
|
|
||||||
) -> tuple[
|
|
||||||
Callable[
|
|
||||||
[
|
|
||||||
torch.Tensor,
|
|
||||||
torch.Tensor,
|
|
||||||
torch.Tensor,
|
|
||||||
torch.Tensor | None,
|
|
||||||
],
|
|
||||||
torch.Tensor,
|
|
||||||
],
|
|
||||||
QuantFP8,
|
|
||||||
]:
|
|
||||||
if use_cutlass:
|
|
||||||
return self._run_cutlass, (
|
|
||||||
QuantFP8(
|
|
||||||
False,
|
|
||||||
self.act_quant_group_shape,
|
|
||||||
column_major_scales=True,
|
|
||||||
use_ue8m0=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if use_aiter_and_is_supported:
|
|
||||||
return self._run_aiter, QuantFP8(
|
|
||||||
False,
|
|
||||||
self.act_quant_group_shape,
|
|
||||||
column_major_scales=False,
|
|
||||||
use_ue8m0=False,
|
|
||||||
)
|
|
||||||
return self._run_triton, (
|
|
||||||
QuantFP8(
|
|
||||||
False,
|
|
||||||
self.act_quant_group_shape,
|
|
||||||
column_major_scales=False,
|
|
||||||
use_ue8m0=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def input_to_float8(
|
def input_to_float8(
|
||||||
x: torch.Tensor, dtype: torch.dtype | None = None
|
x: torch.Tensor, dtype: torch.dtype | None = None
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
@@ -1612,34 +1097,6 @@ def process_fp8_weight_block_strategy(
|
|||||||
return weight, weight_scale
|
return weight, weight_scale
|
||||||
|
|
||||||
|
|
||||||
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
|
||||||
assert layer.weight_block_size is not None
|
|
||||||
|
|
||||||
from vllm.utils.deep_gemm import (
|
|
||||||
is_deep_gemm_e8m0_used,
|
|
||||||
should_use_deepgemm_for_fp8_linear,
|
|
||||||
)
|
|
||||||
|
|
||||||
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
|
|
||||||
# requantize the weight and input to the specific scale
|
|
||||||
# at the same time.
|
|
||||||
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
|
|
||||||
layer.orig_dtype, layer.weight
|
|
||||||
)
|
|
||||||
if should_use_deepgemm:
|
|
||||||
scale_attr = (
|
|
||||||
"weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
|
|
||||||
)
|
|
||||||
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
|
||||||
wq=layer.weight.data,
|
|
||||||
ws=getattr(layer, scale_attr).data,
|
|
||||||
quant_block_shape=tuple(layer.weight_block_size),
|
|
||||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
|
||||||
)
|
|
||||||
replace_parameter(layer, "weight", dg_weight)
|
|
||||||
replace_parameter(layer, scale_attr, dg_weight_scale)
|
|
||||||
|
|
||||||
|
|
||||||
def process_fp8_weight_tensor_strategy_moe(
|
def process_fp8_weight_tensor_strategy_moe(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scales: torch.Tensor,
|
weight_scales: torch.Tensor,
|
||||||
|
|||||||
@@ -171,6 +171,16 @@ kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
|
|||||||
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
|
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
|
||||||
|
|
||||||
|
|
||||||
|
def create_fp8_quant_key(
|
||||||
|
static: bool,
|
||||||
|
group_shape: GroupShape,
|
||||||
|
symmetric: bool = True,
|
||||||
|
scale_dtype: torch.dtype = torch.float32,
|
||||||
|
) -> QuantKey:
|
||||||
|
scale_desc = ScaleDesc(scale_dtype, static, group_shape)
|
||||||
|
return QuantKey(FP8_DTYPE, scale_desc, symmetric=symmetric)
|
||||||
|
|
||||||
|
|
||||||
# Normalize the group_shape to the full extent for any dims that are -1
|
# Normalize the group_shape to the full extent for any dims that are -1
|
||||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||||
# -1 means full extent
|
# -1 means full extent
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
|||||||
|
|
||||||
def should_use_deepgemm_for_fp8_linear(
|
def should_use_deepgemm_for_fp8_linear(
|
||||||
output_dtype: torch.dtype,
|
output_dtype: torch.dtype,
|
||||||
weight: torch.Tensor,
|
weight_shape: tuple[int, int],
|
||||||
supports_deep_gemm: bool | None = None,
|
supports_deep_gemm: bool | None = None,
|
||||||
):
|
):
|
||||||
if supports_deep_gemm is None:
|
if supports_deep_gemm is None:
|
||||||
@@ -428,8 +428,8 @@ def should_use_deepgemm_for_fp8_linear(
|
|||||||
return (
|
return (
|
||||||
supports_deep_gemm
|
supports_deep_gemm
|
||||||
and output_dtype == torch.bfloat16
|
and output_dtype == torch.bfloat16
|
||||||
and weight.shape[0] % N_MULTIPLE == 0
|
and weight_shape[0] % N_MULTIPLE == 0
|
||||||
and weight.shape[1] % K_MULTIPLE == 0
|
and weight_shape[1] % K_MULTIPLE == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -748,8 +748,9 @@ def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
|
|||||||
def should_use_flashinfer_for_blockscale_fp8_gemm(
|
def should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||||
is_flashinfer_supported: bool,
|
is_flashinfer_supported: bool,
|
||||||
output_dtype: torch.dtype,
|
output_dtype: torch.dtype,
|
||||||
input: torch.Tensor,
|
input_dtype: torch.dtype,
|
||||||
weight: torch.Tensor,
|
weight_dtype: torch.dtype,
|
||||||
|
weight_shape: tuple[int, int],
|
||||||
):
|
):
|
||||||
if not is_flashinfer_supported:
|
if not is_flashinfer_supported:
|
||||||
return False
|
return False
|
||||||
@@ -760,15 +761,12 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
|
|||||||
N_MULTIPLE = 64
|
N_MULTIPLE = 64
|
||||||
K_MULTIPLE = 128
|
K_MULTIPLE = 128
|
||||||
|
|
||||||
weight_dtype = weight.dtype
|
|
||||||
input_dtype = input.dtype
|
|
||||||
|
|
||||||
should_use_flashinfer = (
|
should_use_flashinfer = (
|
||||||
output_dtype == torch.bfloat16
|
output_dtype == torch.bfloat16
|
||||||
and input_dtype == torch.bfloat16
|
and input_dtype == torch.bfloat16
|
||||||
and weight_dtype == torch.float8_e4m3fn
|
and weight_dtype == torch.float8_e4m3fn
|
||||||
and weight.shape[0] % N_MULTIPLE == 0
|
and weight_shape[0] % N_MULTIPLE == 0
|
||||||
and weight.shape[1] % K_MULTIPLE == 0
|
and weight_shape[1] % K_MULTIPLE == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
return should_use_flashinfer
|
return should_use_flashinfer
|
||||||
|
|||||||
Reference in New Issue
Block a user