[ROCm][Bugfix] Fix ROCm runtime failure due to missing symbol (#38750)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
Gregory Shtrasberg
2026-04-01 23:30:11 -05:00
committed by GitHub
parent 5a2d420c17
commit 3aab680e3e
2 changed files with 13 additions and 12 deletions

View File

@@ -142,11 +142,13 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
#ifndef USE_ROCM
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
#endif
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,

View File

@@ -109,18 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
@@ -244,6 +232,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");