[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by: tianrengao <terrygao87@gmail.com>
This commit is contained in:
12
csrc/ops.h
12
csrc/ops.h
@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||||
|
|
||||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||||
torch::Tensor& output_scale,
|
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||||
torch::Tensor const& input_scale,
|
bool is_sf_swizzled_layout);
|
||||||
bool is_sf_swizzled_layout);
|
|
||||||
|
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||||
|
torch::Tensor const& input_scale,
|
||||||
|
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||||
|
torch::Tensor& output_scale);
|
||||||
|
|
||||||
void scaled_fp4_experts_quant(
|
void scaled_fp4_experts_quant(
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
torch::Tensor& output, torch::Tensor& output_scale,
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "nvfp4_utils.cuh"
|
||||||
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||||
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
|||||||
torch::Tensor const& output_scale_offset_by_experts);
|
torch::Tensor const& output_scale_offset_by_experts);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||||
torch::Tensor& output_sf, torch::Tensor const& input_sf,
|
torch::Tensor const& input_sf,
|
||||||
bool is_sf_swizzled_layout) {
|
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||||
|
torch::Tensor& output_sf) {
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||||
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
|||||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||||
|
torch::Tensor const& input, torch::Tensor const& input_sf,
|
||||||
|
bool is_sf_swizzled_layout) {
|
||||||
|
int64_t n = input.size(-1);
|
||||||
|
int64_t m = input.numel() / n;
|
||||||
|
auto device = input.device();
|
||||||
|
|
||||||
|
// Two fp4 values packed into a uint8
|
||||||
|
auto output = torch::empty(
|
||||||
|
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||||
|
|
||||||
|
torch::Tensor output_sf;
|
||||||
|
if (is_sf_swizzled_layout) {
|
||||||
|
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||||
|
output_sf = torch::empty(
|
||||||
|
{sf_m, sf_n},
|
||||||
|
torch::TensorOptions().device(device).dtype(torch::kInt32));
|
||||||
|
} else {
|
||||||
|
output_sf = torch::empty(
|
||||||
|
{m, n / CVT_FP4_SF_VEC_SIZE},
|
||||||
|
torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||||
|
}
|
||||||
|
|
||||||
|
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||||
|
output_sf);
|
||||||
|
return {output, output_sf};
|
||||||
|
}
|
||||||
|
|
||||||
void scaled_fp4_experts_quant(
|
void scaled_fp4_experts_quant(
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
torch::Tensor& output, torch::Tensor& output_scale,
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "../../cuda_vec_utils.cuh"
|
#include "../../cuda_vec_utils.cuh"
|
||||||
|
|
||||||
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
|
|||||||
return round_up(m, ROW_TILE);
|
return round_up(m, ROW_TILE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compute the shape of the swizzled SF output tensor.
|
||||||
|
// Returns (rounded_m, rounded_n / 4) where:
|
||||||
|
// rounded_m = round_up(m, 128)
|
||||||
|
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
|
||||||
|
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
|
||||||
|
int64_t n) {
|
||||||
|
int64_t rounded_m = round_up(m, static_cast<int64_t>(128));
|
||||||
|
int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE;
|
||||||
|
int64_t rounded_n = round_up(scale_n, static_cast<int64_t>(4));
|
||||||
|
return {rounded_m, rounded_n / 4};
|
||||||
|
}
|
||||||
|
|
||||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||||
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
|
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
|
||||||
uint32_t val;
|
uint32_t val;
|
||||||
|
|||||||
@@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
|
|
||||||
// Compute NVFP4 block quantized tensor.
|
// Compute NVFP4 block quantized tensor.
|
||||||
ops.def(
|
ops.def(
|
||||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
"scaled_fp4_quant(Tensor input,"
|
||||||
" Tensor! output_scale, Tensor input_scale, bool "
|
" Tensor input_scale, bool "
|
||||||
"is_sf_swizzled_layout) -> ()");
|
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
|
||||||
|
|
||||||
|
// Out variant
|
||||||
|
// TODO: Add {at::Tag::out_variant} tag and update all call sites
|
||||||
|
// to use the functional variant once vLLM upgrades PyTorch.
|
||||||
|
// See pytorch/pytorch#176117.
|
||||||
|
ops.def(
|
||||||
|
"scaled_fp4_quant.out(Tensor input,"
|
||||||
|
" Tensor input_scale, bool "
|
||||||
|
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||||
|
"-> ()");
|
||||||
|
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
|
||||||
|
|
||||||
// Compute NVFP4 experts quantization.
|
// Compute NVFP4 experts quantization.
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
|||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.all_reduce.default,
|
torch.ops.vllm.all_reduce.default,
|
||||||
torch.ops._C.scaled_fp4_quant.default,
|
torch.ops._C.scaled_fp4_quant.out,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
|
|||||||
torch.testing.assert_close(scale_ans, scale_ref)
|
torch.testing.assert_close(scale_ans, scale_ref)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"shape",
|
||||||
|
[(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_python_util_matches_cpp_allocation(
|
||||||
|
shape: tuple[int, int],
|
||||||
|
is_sf_swizzled_layout: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Verify that the Python utility (create_fp4_output_tensors) allocates
|
||||||
|
tensors with the same shapes and dtypes as the C++ functional variant
|
||||||
|
(scaled_fp4_quant_func).
|
||||||
|
"""
|
||||||
|
from vllm._custom_ops import create_fp4_output_tensors
|
||||||
|
|
||||||
|
torch.set_default_device("cuda:0")
|
||||||
|
m, n = shape
|
||||||
|
input_tensor = torch.randn((m, n), dtype=torch.bfloat16)
|
||||||
|
input_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda:0")
|
||||||
|
|
||||||
|
# C++ functional variant allocates internally
|
||||||
|
cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant(
|
||||||
|
input_tensor, input_scale, is_sf_swizzled_layout
|
||||||
|
)
|
||||||
|
|
||||||
|
# Python utility
|
||||||
|
py_out, py_scale = create_fp4_output_tensors(
|
||||||
|
m, n, torch.device("cuda:0"), is_sf_swizzled_layout
|
||||||
|
)
|
||||||
|
|
||||||
|
assert py_out.shape == cpp_out.shape, (
|
||||||
|
f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}"
|
||||||
|
)
|
||||||
|
assert py_out.dtype == cpp_out.dtype, (
|
||||||
|
f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}"
|
||||||
|
)
|
||||||
|
assert py_scale.shape == cpp_scale.shape, (
|
||||||
|
f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}"
|
||||||
|
)
|
||||||
|
assert py_scale.dtype == cpp_scale.dtype, (
|
||||||
|
f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||||
|
|||||||
@@ -29,6 +29,81 @@ else:
|
|||||||
from torch.library import impl_abstract as register_fake
|
from torch.library import impl_abstract as register_fake
|
||||||
|
|
||||||
|
|
||||||
|
# scaled_fp4_quant functional + out variant for torch.compile buffer management
|
||||||
|
|
||||||
|
|
||||||
|
def create_fp4_scale_tensor(
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
device: torch.device,
|
||||||
|
is_sf_swizzled_layout: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Allocate the output scale tensor for scaled_fp4_quant.
|
||||||
|
|
||||||
|
When is_sf_swizzled_layout=True, we use rounded values to store the
|
||||||
|
swizzled scales. Due to the requirement of the Tensor Core, the minimum
|
||||||
|
tile is 128x4 for the scales. So, we first pad the scales to multiples
|
||||||
|
of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are
|
||||||
|
packed into an int32 for every 4 values. More:
|
||||||
|
https://docs.nvidia.com/cuda/parallel-thread-execution/
|
||||||
|
#tcgen05-mma-scale-factor-b-layout-4x
|
||||||
|
"""
|
||||||
|
from vllm.utils.math_utils import round_up
|
||||||
|
|
||||||
|
block_size = 16
|
||||||
|
if is_sf_swizzled_layout:
|
||||||
|
rounded_m = round_up(m, 128)
|
||||||
|
scale_n = n // block_size
|
||||||
|
rounded_n = round_up(scale_n, 4)
|
||||||
|
return torch.empty(
|
||||||
|
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return torch.empty((m, n // block_size), device=device, dtype=torch.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def create_fp4_output_tensors(
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
device: torch.device,
|
||||||
|
is_sf_swizzled_layout: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Allocate both output tensors for scaled_fp4_quant:
|
||||||
|
(quantized_output, output_scale).
|
||||||
|
|
||||||
|
Must match the C++ scaled_fp4_quant_func allocation exactly.
|
||||||
|
"""
|
||||||
|
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||||
|
output_scale = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout)
|
||||||
|
return output, output_scale
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||||
|
|
||||||
|
@register_fake("_C::scaled_fp4_quant")
|
||||||
|
def _scaled_fp4_quant_fake(
|
||||||
|
input: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
|
is_sf_swizzled_layout: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
n = input.shape[-1]
|
||||||
|
m = input.numel() // n
|
||||||
|
return create_fp4_output_tensors(m, n, input.device, is_sf_swizzled_layout)
|
||||||
|
|
||||||
|
@register_fake("_C::scaled_fp4_quant.out")
|
||||||
|
def _scaled_fp4_quant_out_fake(
|
||||||
|
input: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
|
is_sf_swizzled_layout: bool,
|
||||||
|
*,
|
||||||
|
output: torch.Tensor,
|
||||||
|
output_scale: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# page attention ops
|
# page attention ops
|
||||||
def paged_attention_v1(
|
def paged_attention_v1(
|
||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
|
|||||||
input = input.reshape(other_dims, input.shape[-1])
|
input = input.reshape(other_dims, input.shape[-1])
|
||||||
m, n = input.shape
|
m, n = input.shape
|
||||||
block_size = 16
|
block_size = 16
|
||||||
device = input.device
|
|
||||||
|
|
||||||
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
|
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
|
||||||
assert input.dtype in (torch.float16, torch.bfloat16), (
|
assert input.dtype in (torch.float16, torch.bfloat16), (
|
||||||
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
|
|||||||
input, input_global_scale
|
input, input_global_scale
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Two fp4 values will be packed into an uint8.
|
# Pre-allocate and call .out variant (same behavior as old in-place API)
|
||||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
output, output_scale = create_fp4_output_tensors(
|
||||||
if is_sf_swizzled_layout:
|
m, n, input.device, is_sf_swizzled_layout
|
||||||
# We use the rounded values to store the swizzled values. Due to the
|
)
|
||||||
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
|
torch.ops._C.scaled_fp4_quant.out(
|
||||||
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
|
input,
|
||||||
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
|
input_global_scale,
|
||||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
|
is_sf_swizzled_layout,
|
||||||
round_up = lambda x, y: (x + y - 1) // y * y
|
output=output,
|
||||||
rounded_m = round_up(m, 128)
|
output_scale=output_scale,
|
||||||
scale_n = n // block_size
|
|
||||||
rounded_n = round_up(scale_n, 4)
|
|
||||||
output_scale = torch.empty(
|
|
||||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
|
|
||||||
|
|
||||||
torch.ops._C.scaled_fp4_quant(
|
|
||||||
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||||
|
|||||||
@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
|||||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||||
at = auto_functionalized(
|
at = auto_functionalized(
|
||||||
self.QUANT_OP,
|
self.QUANT_OP,
|
||||||
output=result,
|
|
||||||
input=result_silu_mul,
|
input=result_silu_mul,
|
||||||
output_scale=output_scale,
|
|
||||||
input_scale=scale,
|
input_scale=scale,
|
||||||
is_sf_swizzled_layout=True,
|
is_sf_swizzled_layout=True,
|
||||||
|
output=result,
|
||||||
|
output_scale=output_scale,
|
||||||
)
|
)
|
||||||
return at[1], at[2]
|
return at[1], at[2]
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out
|
||||||
|
|
||||||
# Max size of the input tensor per world size per device capability
|
# Max size of the input tensor per world size per device capability
|
||||||
# to use flashinfer fused allreduce
|
# to use flashinfer fused allreduce
|
||||||
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
|||||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||||
quant_out_tuple = auto_functionalized(
|
quant_out_tuple = auto_functionalized(
|
||||||
STATIC_FP4_QUANT_OP,
|
STATIC_FP4_QUANT_OP,
|
||||||
output=quant_result,
|
|
||||||
input=rms,
|
input=rms,
|
||||||
output_scale=output_scale,
|
|
||||||
input_scale=input_global_scale,
|
input_scale=input_global_scale,
|
||||||
is_sf_swizzled_layout=True,
|
is_sf_swizzled_layout=True,
|
||||||
|
output=quant_result,
|
||||||
|
output_scale=output_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# quant_out, allreduce_output, output_scale
|
# quant_out, allreduce_output, output_scale
|
||||||
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
|||||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||||
quant_out_tuple = auto_functionalized(
|
quant_out_tuple = auto_functionalized(
|
||||||
STATIC_FP4_QUANT_OP,
|
STATIC_FP4_QUANT_OP,
|
||||||
output=quant_result,
|
|
||||||
input=rms,
|
input=rms,
|
||||||
output_scale=output_scale,
|
|
||||||
input_scale=input_global_scale,
|
input_scale=input_global_scale,
|
||||||
is_sf_swizzled_layout=True,
|
is_sf_swizzled_layout=True,
|
||||||
|
output=quant_result,
|
||||||
|
output_scale=output_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# quant_out, allreduce_output, output_scale
|
# quant_out, allreduce_output, output_scale
|
||||||
|
|||||||
@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
)
|
)
|
||||||
at2 = auto_functionalized(
|
at2 = auto_functionalized(
|
||||||
self.QUANT_OP,
|
self.QUANT_OP,
|
||||||
output=output_quant,
|
|
||||||
input=attn_out_view,
|
input=attn_out_view,
|
||||||
output_scale=output_scale,
|
|
||||||
input_scale=input_scale,
|
input_scale=input_scale,
|
||||||
is_sf_swizzled_layout=True,
|
is_sf_swizzled_layout=True,
|
||||||
|
output=output_quant,
|
||||||
|
output_scale=output_scale,
|
||||||
)
|
)
|
||||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||||
return at2[1], output_scale_view
|
return at2[1], output_scale_view
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
|||||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||||
}
|
}
|
||||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
|
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
|
|||||||
Reference in New Issue
Block a user