[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (#4343)

commit 12628d3c78 (parent 258a2c58d0)
Author: Philipp Moritz
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date: 2024-04-26 21:49:59 -07:00
Committed by: GitHub

7 changed files with 95 additions and 18 deletions


@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
 
 } // namespace vllm
 
-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., d]
+  torch::Tensor& scale)    // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  dim3 grid(num_tokens);
+  dim3 block(1024);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "scaled_fp8_quant_kernel",
+    [&] {
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+    });
+}
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input,    // [..., d]
   torch::Tensor& scale)    // [1]
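
For context, below is a minimal standalone sketch of the elementwise quantization that a launch like the one above would dispatch to. It is an illustration under stated assumptions, not the vLLM kernel: the kernel name is hypothetical, CUDA's `__nv_fp8_e4m3` type stands in for `c10::Float8_e4m3fn`, and a plain `float` input stands in for the dispatched `scalar_t`; the actual kernel body lies outside this hunk.

```cuda
// Hypothetical sketch of a static-scale FP8 quantization kernel.
#include <cuda_fp8.h>
#include <cstdint>

__global__ void static_scaled_fp8_quant_sketch(
    __nv_fp8_e4m3* __restrict__ out,   // [..., d] quantized output
    const float* __restrict__ input,   // [..., d] full-precision input
    const float* __restrict__ scale,   // [1] precomputed (static) scale
    int64_t num_elems) {
  // Grid-stride loop: each thread quantizes a strided subset of elements.
  for (int64_t i = blockDim.x * blockIdx.x + threadIdx.x; i < num_elems;
       i += (int64_t)blockDim.x * gridDim.x) {
    // Static quantization: divide by the fixed scale and round to FP8 E4M3.
    out[i] = __nv_fp8_e4m3(input[i] / *scale);
  }
}
```

The point of the split is visible in the host code above: `static_scaled_fp8_quant` launches only the quantization kernel with a caller-provided scale, whereas a dynamic variant must first reduce over the input to derive the scale (e.g., from the maximum absolute value) before quantizing, costing an extra pass over the data.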