[2/n] Migrate per_token_group_quant to torch stable ABI (#36058)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit adds one new file:
  csrc/libtorch_stable/dispatch_utils.h (60 lines, new file)
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
* Stable ABI compatible dispatch utilities for vLLM.
|
||||
* Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros
|
||||
* instead of the ATen (AT_*) macros.
|
||||
*
|
||||
* These macros use:
|
||||
* - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH
|
||||
* - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE
|
||||
* - torch::headeronly::ScalarType instead of at::ScalarType
|
||||
*
|
||||
* Add more macros here as needed when migrating additional kernels.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t',
// so an FP8 dispatch nested inside a floating-type dispatch does not shadow
// the outer 'scalar_t' binding.
#define VLLM_STABLE_DISPATCH_FP8_CASE(enum_type, ...) \
  THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
|
||||
|
||||
// Case list for the floating-point dtypes vLLM kernels handle:
// fp32 (Float), fp16 (Half) and bf16 (BFloat16). Each case binds the
// dispatched C++ type as 'scalar_t' (the THO_* analogue of AT_DISPATCH_CASE).
#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...)                    \
  THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__)   \
  THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__)    \
  THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
// Stable-ABI replacement for VLLM_DISPATCH_FLOATING_TYPES: switches on TYPE
// (a torch::headeronly::ScalarType) and runs the trailing callable with
// 'scalar_t' bound to the matching C++ type. NAME is presumably included in
// the unsupported-dtype error, mirroring AT_DISPATCH_SWITCH — confirm against
// torch/headeronly/core/Dispatch.h.
#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  THO_DISPATCH_SWITCH(TYPE, NAME,                            \
                      VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
// FP8 type dispatch - ROCm uses FNUZ format, CUDA uses OCP format.
// Under USE_ROCM both e4m3fn (OCP) and e4m3fnuz (FNUZ) cases are emitted, so
// ROCm builds accept tensors in either storage format; CUDA builds only
// dispatch e4m3fn. The dispatched type is named 'fp8_t' (see
// VLLM_STABLE_DISPATCH_FP8_CASE above).
#ifdef USE_ROCM
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...)                   \
  VLLM_STABLE_DISPATCH_FP8_CASE(                                   \
      torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
  VLLM_STABLE_DISPATCH_FP8_CASE(                                   \
      torch::headeronly::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#else
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \
  VLLM_STABLE_DISPATCH_FP8_CASE(                 \
      torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#endif
|
||||
|
||||
// FP8 dispatch switch over TYPE (a torch::headeronly::ScalarType).
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See VLLM_STABLE_DISPATCH_FP8_CASE above.
#define VLLM_STABLE_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
  THO_DISPATCH_SWITCH(TYPE, NAME,                       \
                      VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
// Boolean dispatch: evaluates 'expr' at runtime and invokes the callable in
// __VA_ARGS__ with 'const_expr' in scope as a compile-time constant — 'true'
// in the taken branch, 'false' otherwise — so the callable can use it as a
// template argument or 'if constexpr' condition.
//
// Wrapped in do { } while (0) so the expansion is a single statement: the
// bare if/else form breaks when the macro call (plus its trailing ';') is
// used as the body of an unbraced 'if' that has a following 'else'
// (stray-semicolon / dangling-else hazard).
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
  do {                                                   \
    if (expr) {                                          \
      constexpr bool const_expr = true;                  \
      __VA_ARGS__();                                     \
    } else {                                             \
      constexpr bool const_expr = false;                 \
      __VA_ARGS__();                                     \
    }                                                    \
  } while (0)
|
||||
Reference in New Issue
Block a user