[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
@@ -9,11 +9,12 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
 import torch
 
 from vllm.benchmarks.lib.utils import default_vllm_config
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    W8A8BlockFp8LinearOp,
+from vllm.model_executor.kernels.linear import (
+    init_fp8_linear_kernel,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
+    create_fp8_quant_key,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     weight_group_shape = GroupShape(block_n, block_k)
     act_quant_group_shape = GroupShape(1, block_k)  # Per-token, per-group quantization
 
-    linear_op = W8A8BlockFp8LinearOp(
-        weight_group_shape=weight_group_shape,
-        act_quant_group_shape=act_quant_group_shape,
-        cutlass_block_fp8_supported=use_cutlass,
-        use_aiter_and_is_supported=False,
+    linear_op = init_fp8_linear_kernel(
+        weight_quant_key=create_fp8_quant_key(
+            static=True, group_shape=weight_group_shape
+        ),
+        activation_quant_key=create_fp8_quant_key(
+            static=False, group_shape=act_quant_group_shape
+        ),
+        out_dtype=torch.get_default_dtype(),
+        module_name="build_w8a8_block_fp8_runner",
     )
 
     def run():
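
Below is a minimal sketch of the kernel-selection flow the hunks above adopt, limited to the symbols the diff itself shows (init_fp8_linear_kernel, create_fp8_quant_key, GroupShape). The wrapper function name, the example 128x128 block size, and the module_name string are illustrative assumptions, not part of the change; how the returned op is applied is left to the benchmark's run() closure and is not reproduced here.

# Hedged sketch of the new FP8 block linear kernel selection.
# Only the imports and keyword arguments mirror the diff; the helper name,
# the 128x128 block size, and the module_name value are assumed for illustration.
import torch

from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    create_fp8_quant_key,
)


def build_block_fp8_linear_op(block_size: tuple[int, int]):
    """Select an FP8 linear kernel for a (block_n, block_k) weight layout."""
    block_n, block_k = block_size

    # Weights carry static per-(block_n, block_k)-tile scales; activations are
    # quantized dynamically per token, per (1, block_k) group.
    weight_group_shape = GroupShape(block_n, block_k)
    act_quant_group_shape = GroupShape(1, block_k)

    return init_fp8_linear_kernel(
        weight_quant_key=create_fp8_quant_key(
            static=True, group_shape=weight_group_shape
        ),
        activation_quant_key=create_fp8_quant_key(
            static=False, group_shape=act_quant_group_shape
        ),
        out_dtype=torch.get_default_dtype(),
        module_name="build_block_fp8_linear_op",
    )


# Example: DeepSeek-style 128x128 weight blocks.
linear_op = build_block_fp8_linear_op((128, 128))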