Kernel-override Determinism [1/n] (#25603)

Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
Bram Wasti
2025-09-26 19:59:09 -04:00
committed by GitHub
parent 4778b42660
commit dc48ba0c75
8 changed files with 890 additions and 4 deletions

View File

@@ -9,6 +9,7 @@
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
@@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);