Files
vllm/csrc/libtorch_stable/dispatch_utils.h
2026-03-25 10:15:13 -07:00

61 lines
2.7 KiB
C++

/*
* Stable ABI compatible dispatch utilities for vLLM.
* Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros
* instead of the ATen (AT_*) macros.
*
* These macros use:
* - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH
* - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE
* - torch::headeronly::ScalarType instead of at::ScalarType
*
* Add more macros here as needed when migrating additional kernels.
*/
#pragma once
#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>
// Single switch case for the FP8 dispatch macros further down. FP8 dispatch is
// meant to be nested inside another dispatch, which normally names the
// dispatched type 'scalar_t'. To avoid a clash, this case names the dispatched
// C++ type 'fp8_t' instead.
#define VLLM_STABLE_DISPATCH_FP8_CASE(SCALAR_ENUM, ...) \
  THO_PRIVATE_CASE_TYPE_USING_HINT(SCALAR_ENUM, fp8_t, __VA_ARGS__)
// Case list for the three floating-point dtypes: Float (fp32), Half (fp16)
// and BFloat16. Each entry is a THO_DISPATCH_CASE, so the dispatched type gets
// THO's default alias (conventionally 'scalar_t'). The list is meant to be
// spliced into a THO_DISPATCH_SWITCH.
#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...)                    \
  THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__)   \
  THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__)    \
  THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
// Switches on TYPE (presumably a torch::headeronly::ScalarType — confirm
// against callers) over Float, Half and BFloat16. NAME is forwarded unchanged
// to THO_DISPATCH_SWITCH. The trailing arguments form the body that is run for
// whichever dtype matches.
#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  THO_DISPATCH_SWITCH(                                       \
      TYPE, NAME, VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// FP8 (e4m3) case list. The set of dtypes depends on the build:
//   - non-ROCm builds (USE_ROCM undefined): only the OCP encoding,
//     Float8_e4m3fn;
//   - ROCm builds (USE_ROCM defined): both the OCP encoding (Float8_e4m3fn)
//     and the FNUZ encoding (Float8_e4m3fnuz).
// Every entry goes through VLLM_STABLE_DISPATCH_FP8_CASE, so the dispatched
// type is named 'fp8_t'.
#ifndef USE_ROCM
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \
  VLLM_STABLE_DISPATCH_FP8_CASE(                 \
      torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#else
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...)                 \
  VLLM_STABLE_DISPATCH_FP8_CASE(                                 \
      torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
  VLLM_STABLE_DISPATCH_FP8_CASE(                                 \
      torch::headeronly::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#endif
// Switches on TYPE over the cases produced by
// VLLM_STABLE_DISPATCH_CASE_FP8_TYPES. NAME is forwarded to
// THO_DISPATCH_SWITCH. Inside the trailing body the matched type is named
// 'fp8_t' (not 'scalar_t'), because every case is built with
// VLLM_STABLE_DISPATCH_FP8_CASE.
#define VLLM_STABLE_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
  THO_DISPATCH_SWITCH(                                  \
      TYPE, NAME, VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Boolean dispatch: turns the runtime condition `expr` into a compile-time
// constant.
//
// The macro evaluates `expr` once. It then declares
// `constexpr bool const_expr` (the identifier is supplied by the caller),
// set to true or false to match `expr`, and invokes the trailing callable
// (typically a lambda) with that constant in scope. Exactly one branch runs.
// Because the constant is constexpr, the callable can use it in constant
// expressions, e.g. as a template argument.
//
// The expansion is wrapped in do { ... } while (0) so that it is a single
// statement. It must be followed by a semicolon, and it can safely serve as
// the unbraced body of an if/else, e.g.
//   if (c) VLLM_STABLE_DISPATCH_BOOL(...); else other();
//
// Example:
//   VLLM_STABLE_DISPATCH_BOOL(use_bias, kUseBias, [&] {
//     launch<kUseBias>(args);
//   });
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
  do {                                                   \
    if (expr) {                                          \
      constexpr bool const_expr = true;                  \
      __VA_ARGS__();                                     \
    } else {                                             \
      constexpr bool const_expr = false;                 \
      __VA_ARGS__();                                     \
    }                                                    \
  } while (0)