[ROCm][Bugfix] Fix ROCm runtime failure due to missing symbol (#38750)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
committed by
GitHub
parent
5a2d420c17
commit
3aab680e3e
@@ -142,11 +142,13 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
std::optional<torch::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
#endif
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
|
||||
@@ -109,18 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
|
||||
@@ -244,6 +232,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
|
||||
ops.def(
|
||||
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
|
||||
Reference in New Issue
Block a user