dynamic distpatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@@ -21,9 +21,9 @@
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck. */
|
||||
template <typename scalar_t, int width>
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
out[id * width + i] =
|
||||
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
/* Generic fused_add_rms_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
|
||||
num_tokens, hidden_size);
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
|
||||
epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
|
||||
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
|
||||
VLLM_DISPATCH_FP8_TYPES( \
|
||||
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
|
||||
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
|
||||
width, fp8_t> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
}); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm_static_fp8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size],
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
|
||||
Reference in New Issue
Block a user